{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ATlvvGTUto8T",
    "outputId": "48d62cee-4f18-4e75-913a-fcbc97366c04"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7b187c0e00d0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "torch.manual_seed(12046)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "UV7wHrEFto8V"
   },
   "outputs": [],
   "source": [
    "class LSTMCell(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        '''\n",
    "        长短期记忆网络的神经元\n",
    "        参数\n",
    "        ----\n",
    "        input_size ：int，输入数据的特征长度\n",
    "        hidden_size ：int，隐藏状态的特征长度\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        combined_size = self.input_size + self.hidden_size\n",
    "        # 定义输入门的线性部分\n",
    "        self.in_gate = nn.Linear(combined_size, self.hidden_size)\n",
    "        # 定义遗忘门的线性部分\n",
    "        self.forget_gate = nn.Linear(combined_size, self.hidden_size)\n",
    "        # 定义备选细胞状态的线性部分\n",
    "        self.new_cell_state = nn.Linear(combined_size, self.hidden_size)\n",
    "        # 定义输出门的线性部分\n",
    "        self.out_gate = nn.Linear(combined_size, self.hidden_size)\n",
    "\n",
    "    def forward(self, inputs, state=None):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        inputs ：torch.FloatTensor\n",
    "            输入数据，形状为(B, I)，其中B表示批量大小，I表示文字特征的长度（input_size）\n",
    "        state ：tuple(torch.FloatTensor, torch.FloatTensor)\n",
    "            (隐藏状态，细胞状态)，两个状态的形状都为(B, H)，其中H表示隐藏状态的长度（hidden_size）\n",
    "        返回\n",
    "        ----\n",
    "        hs ：torch.FloatTensor，隐藏状态，形状为(B, H)\n",
    "        cs ：torch.FloatTensor，细胞状态，形状为(B, H)\n",
    "        '''\n",
    "        B, _ = inputs.shape\n",
    "        if state is None:\n",
    "            state = self.init_state(B, inputs.device)\n",
    "        hs, cs = state\n",
    "        combined = torch.cat((inputs, hs), dim=1)           # (B, I + H)\n",
    "        # 输入门\n",
    "        ingate = F.sigmoid(self.in_gate(combined))          # (B,     H)\n",
    "        # 遗忘门\n",
    "        forgetgate = F.sigmoid(self.forget_gate(combined))  # (B,     H)\n",
    "        # 输出门\n",
    "        outgate = F.sigmoid(self.out_gate(combined))        # (B,     H)\n",
    "        # 更新细胞状态\n",
    "        ncs = F.tanh(self.new_cell_state(combined))         # (B,     H)\n",
    "        cs = (forgetgate * cs) + (ingate * ncs)             # (B,     H)\n",
    "        # 更新隐藏状态\n",
    "        hs = outgate * F.tanh(cs)                           # (B,     H)\n",
    "        return hs, cs\n",
    "\n",
    "    def init_state(self, B, device):\n",
    "        # 默认的隐藏状态和细胞状态全部都等于0\n",
    "        cs = torch.zeros((B, self.hidden_size), device=device)\n",
    "        hs = torch.zeros((B, self.hidden_size), device=device)\n",
    "        return hs, cs\n",
    "\n",
    "class LSTM(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        '''\n",
    "        单层的长短期记忆网络（支持批量计算）\n",
    "        参数\n",
    "        ----\n",
    "        input_size ：int，输入数据的特征长度\n",
    "        hidden_size ：int，隐藏状态的特征长度\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.lstm = LSTMCell(self.input_size, self.hidden_size)\n",
    "\n",
    "    def forward(self, inputs, state=None):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        inputs ：torch.FloatTensor\n",
    "            输入数据的集合，形状为(B, T, C)，其中B表示批量大小，T表示文本长度，C表示文字特征的长度（input_size）\n",
    "        state ：tuple(torch.FloatTensor, torch.FloatTensor)\n",
    "            (初始的隐藏状态，初始的细胞状态)，两个状态的形状都为(B, H)，其中H表示隐藏状态的长度（hidden_size）\n",
    "        返回\n",
    "        ----\n",
    "        hidden ：torch.FloatTensor，所有隐藏状态的集合，形状为(B, T, H)\n",
    "        '''\n",
    "        re = []\n",
    "        B, T, C = inputs.shape\n",
    "        inputs = inputs.transpose(0, 1)  # (T, B, C)\n",
    "        for i in range(T):\n",
    "            state = self.lstm(inputs[i], state)\n",
    "            # 只记录隐藏状态，state[0]的形状为(B, H)\n",
    "            re.append(state[0])\n",
    "        result_tensor = torch.stack(re, dim=0)  # (T, B, H)\n",
    "        return result_tensor.transpose(0, 1)    # (B, T, H)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aGr01cqBto8W",
    "outputId": "ce3ac010-0ac5-4dde-fb31-4543e7f6c500"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(True), (2, 17, 15, 16, 15))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test_lstm():\n",
    "    '''\n",
    "    测试LSTM实现的准确性\n",
    "    '''\n",
    "    # 随机生成模型结构\n",
    "    B, T, input_size, hidden_size, num_layers = torch.randint(1, 20, (5,)).tolist()\n",
    "    ref_model = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)\n",
    "    # 随机生成输入\n",
    "    inputs = torch.randn(B, T, input_size)\n",
    "    hs, cs = torch.randn((2 * num_layers, B, hidden_size)).chunk(2, 0)\n",
    "    _hs = list((i.squeeze(0) for i in hs))\n",
    "    _cs = list((i.squeeze(0) for i in cs))\n",
    "    re = inputs\n",
    "    # 取出模型参数\n",
    "    for layer_index in range(num_layers):\n",
    "        l = ref_model.all_weights[layer_index]\n",
    "        if layer_index == 0:\n",
    "            model = LSTM(input_size, hidden_size)\n",
    "        else:\n",
    "            model = LSTM(hidden_size, hidden_size)\n",
    "        i, f, c, o = torch.cat((l[0], l[1]), dim=1).chunk(4, 0)\n",
    "        ib, fb, cb, ob = (l[2] + l[3]).chunk(4, 0)\n",
    "        # 设置模型参数\n",
    "        model.lstm.in_gate.weight = nn.Parameter(i)\n",
    "        model.lstm.in_gate.bias = nn.Parameter(ib)\n",
    "        model.lstm.forget_gate.weight = nn.Parameter(f)\n",
    "        model.lstm.forget_gate.bias = nn.Parameter(fb)\n",
    "        model.lstm.new_cell_state.weight = nn.Parameter(c)\n",
    "        model.lstm.new_cell_state.bias = nn.Parameter(cb)\n",
    "        model.lstm.out_gate.weight = nn.Parameter(o)\n",
    "        model.lstm.out_gate.bias = nn.Parameter(ob)\n",
    "        # 计算隐藏状态\n",
    "        re = model(re, (_hs[layer_index], _cs[layer_index]))\n",
    "    ref_re, _ = ref_model(inputs, (hs, cs))\n",
    "    # 验证计算结果（最后一层的隐藏状态是否一致）\n",
    "    out = torch.all(torch.abs(re - ref_re) < 1e-4)\n",
    "    return out, (B, T, input_size, hidden_size, num_layers)\n",
    "\n",
    "test_lstm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "FDmaQtfbto8Z"
   },
   "outputs": [],
   "source": [
    "# 一些超参数\n",
    "learning_rate = 1e-3\n",
    "eval_iters = 10\n",
    "batch_size=1000\n",
    "sequence_len=64\n",
    "# 如果有GPU，该脚本将使用GPU进行计算\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "iKwo5iGnto8Z",
    "outputId": "33b6bdaa-b2b1-43eb-f318-c40ff76be96d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "98"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets = load_dataset(\"code_search_net\", \"python\")\n",
    "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n",
    "\n",
    "class char_tokenizer:\n",
    "\n",
    "    def __init__(self, data):\n",
    "        # 数据中出现的所有字符构成字典\n",
    "        chars = sorted(list(set(''.join(data))))\n",
    "        # 预留一个位置给结尾的特殊字符\n",
    "        self.char2ind = {s : i + 1 for i, s in enumerate(chars)}\n",
    "        self.char2ind['<|e|>'] = 0\n",
    "        self.ind2char = {i : s for s, i in self.char2ind.items()}\n",
    "\n",
    "    def encode(self, text):\n",
    "        return [self.char2ind[c] for c in text]\n",
    "\n",
    "    def decode(self, enc):\n",
    "        if isinstance(enc, int):\n",
    "            return self.ind2char[enc]\n",
    "        return [self.ind2char[i] for i in enc]\n",
    "\n",
    "tok = char_tokenizer(datasets['whole_func_string'])\n",
    "len(tok.char2ind)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "GS3Cfz2wto8a"
   },
   "outputs": [],
   "source": [
    "class CharLSTM(nn.Module):\n",
    "\n",
    "    def __init__(self, vs):\n",
    "        '''\n",
    "        三层的长短期记忆网络\n",
    "        参数\n",
    "        ----\n",
    "        vs ：int，字典大小\n",
    "        '''\n",
    "        super().__init__()\n",
    "        # 定义文字嵌入的特征长度\n",
    "        self.emb_size = 256\n",
    "        # 定义隐藏状态的特征长度\n",
    "        self.hidden_size = 128\n",
    "        # 文字嵌入层\n",
    "        self.embedding = nn.Embedding(vs, self.emb_size)\n",
    "        # 随机失活\n",
    "        self.dp = nn.Dropout(0.4)\n",
    "        # 第一层长短期记忆网络\n",
    "        self.lstm1 = LSTM(self.emb_size, self.hidden_size)\n",
    "        # 层归一化\n",
    "        self.norm1 = nn.LayerNorm(self.hidden_size)\n",
    "        self.lstm2 = LSTM(self.hidden_size, self.hidden_size)\n",
    "        self.norm2 = nn.LayerNorm(self.hidden_size)\n",
    "        self.lstm3 = LSTM(self.hidden_size, self.hidden_size)\n",
    "        self.norm3 = nn.LayerNorm(self.hidden_size)\n",
    "        # 语言建模头，根据最后一层的隐藏状态预测下一个字母是什么\n",
    "        self.h2o = nn.Linear(self.hidden_size, vs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，当前字母在字典中的位置，形状为(B, T)\n",
    "        返回\n",
    "        ----\n",
    "        output ：torch.FloatTensor，预测结果的logits，形状为(B, T, vs)\n",
    "        '''\n",
    "        emb = self.embedding(x)                   # (B, T,  C)\n",
    "        h = self.norm1(self.dp(self.lstm1(emb)))  # (B, T,  H)\n",
    "        # 第一层的隐藏状态是第二层的输入\n",
    "        h = self.norm2(self.dp(self.lstm2(h)))    # (B, T,  H)\n",
    "        # 第二层的隐藏状态是第三层的输入\n",
    "        h = self.norm3(self.dp(self.lstm3(h)))    # (B, T,  H)\n",
    "        # 使用第三层的隐藏状态预测下一个字母是什么\n",
    "        output = self.h2o(h)                      # (B, T, vs)\n",
    "        return output\n",
    "\n",
    "model = CharLSTM(len(tok.char2ind)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "zc5jC4dxto8a",
    "outputId": "44079c5c-396e-490f-fe69-47d8cffafdd9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CharLSTM(\n",
       "  (embedding): Embedding(98, 256)\n",
       "  (dp): Dropout(p=0.4, inplace=False)\n",
       "  (lstm1): LSTM(\n",
       "    (lstm): LSTMCell(\n",
       "      (in_gate): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (forget_gate): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (new_cell_state): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (out_gate): Linear(in_features=384, out_features=128, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "  (lstm2): LSTM(\n",
       "    (lstm): LSTMCell(\n",
       "      (in_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (forget_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (new_cell_state): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (out_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "  (lstm3): LSTM(\n",
       "    (lstm): LSTMCell(\n",
       "      (in_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (forget_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (new_cell_state): Linear(in_features=256, out_features=128, bias=True)\n",
       "      (out_gate): Linear(in_features=256, out_features=128, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "  (h2o): Linear(in_features=128, out_features=98, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 展示模型结构\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "SfRdGvxoto8a"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def generate_batch(model, idx, max_new_tokens=300):\n",
    "    '''\n",
    "    利用模型生成文本（反复使用模型进行预测）\n",
    "    参数\n",
    "    ----\n",
    "    model ：CharLSTM，生成文本的模型\n",
    "    idx ：torch.LongTensor，当前字母在字典中的位置，形状为(1, T)\n",
    "    max_new_tokens ：int，生成文本的最大长度\n",
    "    返回\n",
    "    ----\n",
    "    out ：list[int]，生成的文本\n",
    "    '''\n",
    "    # 将模型切换至评估模式\n",
    "    model.eval()\n",
    "    for _ in range(max_new_tokens):\n",
    "        # 限制背景长度，使之与模型训练时的状况更相符\n",
    "        # 当然也可以不限制\n",
    "        context = idx[:, -sequence_len:]\n",
    "        # 在文本生成时，模型的计算效率很低，因为有很多重复计算\n",
    "        logits = model(context)\n",
    "        # 只使用最后一个预测结果\n",
    "        logits = logits[:, -1, :]\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        # 根据模型预测的概率，得到最终的预测结果（下一个字母）\n",
    "        # 这一步运算有一定随机性\n",
    "        ix = torch.multinomial(probs, num_samples=1)\n",
    "        idx = torch.cat((idx, ix), dim=1)\n",
    "        if ix.item() == 0:\n",
    "            break\n",
    "    # 将模型切换至训练模式\n",
    "    model.train()\n",
    "    return idx.tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DfHreqdJto8a",
    "outputId": "b05db6d9-a1e0-4695-e814-702448958986"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def*$O(h/of(\"YP{so.8G|1w=3:1'ZS?z9)N[{3Q=CKfAM:iEca\";+Q31<sA..WS$M0Nx!qyT3jyMö54a)'W~]\\r/&B\"T{Y\n",
      "cdtM>SDax1zk<|e|>\n"
     ]
    }
   ],
   "source": [
    "# 使用模型来生成文本\n",
    "begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n",
    "print(''.join(tok.decode(generate_batch(model, begin_text))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ngIplGT3to8b",
    "outputId": "50fa082b-6a34-4a73-ab87-43d7bada48ee"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([605913, 64]), torch.Size([605913, 64]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def process(data, sequence_len=sequence_len):\n",
    "    '''\n",
    "    根据文本生成训练数据\n",
    "    '''\n",
    "    # text是字符串列表\n",
    "    text = data['whole_func_string']\n",
    "    inputs, labels = [], []\n",
    "    for i in text:\n",
    "        enc = tok.encode(i)\n",
    "        # 0对应着文本结束\n",
    "        enc += [0]\n",
    "        # 将文本转换为多个训练数据\n",
    "        for i in range(len(enc) - sequence_len):\n",
    "            inputs.append(enc[i: i + sequence_len])\n",
    "            # 预测标签是下一个字母，因此只需要挪动一个位置即可\n",
    "            labels.append(enc[i + 1: i + 1 + sequence_len])\n",
    "    return {'inputs': inputs, 'labels': labels}\n",
    "\n",
    "# 将数据分为训练集和测试集\n",
    "tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n",
    "# 将文本转换为训练数据，里面包含inputs和labels\n",
    "tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n",
    "tokenized.set_format(type='torch', device=device)\n",
    "\n",
    "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cFPy9_AWto8b",
    "outputId": "91eed388-c6a1-4872-c993-514d055da2a1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'inputs': tensor([[71, 80, 88,  ..., 43, 48, 40],\n",
       "         [82, 57, 75,  ..., 71, 78, 71],\n",
       "         [91,  2, 85,  ..., 85, 71, 86],\n",
       "         ...,\n",
       "         [71,  2, 54,  ..., 79, 71, 65],\n",
       "         [ 2,  2,  2,  ...,  1,  1,  2],\n",
       "         [84, 75, 80,  ..., 85, 86, 84]], device='cuda:0'),\n",
       " 'labels': tensor([[80, 88,  2,  ..., 48, 40, 49],\n",
       "         [57, 75, 86,  ..., 78, 71, 79],\n",
       "         [ 2, 85, 71,  ..., 71, 86, 10],\n",
       "         ...,\n",
       "         [ 2, 54, 91,  ..., 71, 65, 65],\n",
       "         [ 2,  2,  4,  ...,  1,  2,  2],\n",
       "         [75, 80, 73,  ..., 86, 84, 75]], device='cuda:0')}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 构建数据读取器\n",
    "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n",
    "# 获取一个批量的数据\n",
    "next(iter(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "oXR279ncto8b",
    "outputId": "bb83d85b-a4e9-44ee-b11d-1db0c0a2d471"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': 4.7519965171813965, 'test': 4.765100002288818}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def estimate_loss(model):\n",
    "    re = {}\n",
    "    # 将模型切换至评估模式\n",
    "    model.eval()\n",
    "    re['train'] = _loss(model, train_loader)\n",
    "    re['test'] = _loss(model, test_loader)\n",
    "    # 将模型切换至训练模式\n",
    "    model.train()\n",
    "    return re\n",
    "\n",
    "@torch.no_grad()\n",
    "def _loss(model, data_loader):\n",
    "    \"\"\"\n",
    "    计算模型在不同数据集下面的评估指标\n",
    "    \"\"\"\n",
    "    loss = []\n",
    "    data_iter= iter(data_loader)\n",
    "    # 随机使用多个批量数据来预估模型效果\n",
    "    for k in range(eval_iters):\n",
    "        data = next(data_iter, None)\n",
    "        if data is None:\n",
    "            data_iter = iter(data_loader)\n",
    "            data = next(data_iter, None)\n",
    "        inputs, labels = data['inputs'], data['labels']\n",
    "        logits = model(inputs)\n",
    "        # 根据cross_entropy的定义，需要对logits进行转置运算\n",
    "        # 具体细节请参考cross_entropy的官方文档\n",
    "        logits = logits.transpose(-2, -1)\n",
    "        loss.append(F.cross_entropy(logits, labels).item())\n",
    "    return torch.tensor(loss).mean().item()\n",
    "\n",
    "estimate_loss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "5PpRhC4Oto8c"
   },
   "outputs": [],
   "source": [
    "def train_lstm(model, optimizer, data_loader, epochs=10):\n",
    "    lossi = []\n",
    "    for epoch in range(epochs):\n",
    "        for i, data in enumerate(data_loader, 0):\n",
    "            inputs, labels = data['inputs'], data['labels']\n",
    "            optimizer.zero_grad()\n",
    "            logits = model(inputs)\n",
    "            # 根据cross_entropy的定义，需要对logits进行转置运算\n",
    "            # 具体细节请参考cross_entropy的官方文档\n",
    "            logits = logits.transpose(-2, -1)\n",
    "            loss = F.cross_entropy(logits, labels)\n",
    "            lossi.append(loss.item())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # 评估模型，并输出结果\n",
    "        stats = estimate_loss(model)\n",
    "        train_loss = f'train loss {stats[\"train\"]:.4f}'\n",
    "        test_loss = f'test loss {stats[\"test\"]:.4f}'\n",
    "        print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n",
    "    return lossi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "WUVdQEcAto8c",
    "outputId": "63073b27-622e-45e0-bfaf-7a2c3827b5fc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch  0: train loss 1.2563, test loss 1.4122\n",
      "epoch  1: train loss 1.1342, test loss 1.3129\n",
      "epoch  2: train loss 1.0388, test loss 1.2483\n",
      "epoch  3: train loss 0.9971, test loss 1.2172\n",
      "epoch  4: train loss 0.9649, test loss 1.2048\n",
      "epoch  5: train loss 0.9491, test loss 1.1944\n",
      "epoch  6: train loss 0.9319, test loss 1.1899\n",
      "epoch  7: train loss 0.9200, test loss 1.1925\n",
      "epoch  8: train loss 0.9045, test loss 1.1841\n",
      "epoch  9: train loss 0.8960, test loss 1.1883\n"
     ]
    }
   ],
   "source": [
    "l = train_lstm(model, optim.Adam(model.parameters(), lr=learning_rate), train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 448
    },
    "id": "SQ7oEi-Pto8c",
    "outputId": "2f987df6-3036-47c0-f8ed-139be38bf2aa"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7b17187054b0>]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5eUlEQVR4nO3de3iU9Z3//9dMkpkcZ0ISciIJBIMcTUTkELDiAUW0rux2rXX1h7bqfrXY1dXfdou7PWx7tfG31nXdfl2sa5V2LaW1FbQUQUQBUUBBohwUCaeEkAMQck4myczn90eSgUCCmZDMnWSej+uaS2bu+555z0dIXtfndNuMMUYAAAAWsVtdAAAACG2EEQAAYCnCCAAAsBRhBAAAWIowAgAALEUYAQAAliKMAAAASxFGAACApcKtLqA3fD6fjh8/rri4ONlsNqvLAQAAvWCMUV1dndLT02W399z/MSTCyPHjx5WZmWl1GQAAoA9KSkqUkZHR4/EhEUbi4uIktX8Zl8tlcTUAAKA3amtrlZmZ6f893pMhEUY6h2ZcLhdhBACAIebLplgwgRUAAFiKMAIAACxFGAEAAJYijAAAAEsRRgAAgKUIIwAAwFKEEQAAYCnCCAAAsBRhBAAAWIowAgAALEUYAQAAliKMAAAASw2JG+UNlF9tOaySqkZ9Y0amJqRyAz4AAKwQ0j0jqz89rmUfHFHxqUarSwEAIGSFdBixd9zS2GeMxZUAABC6QjqMhPnDiMWFAAAQwkI6jHRkEXlJIwAAWCakw0iYnWEaAACsFtJhhDkjAABYL7TDSGfPiM/iQgAACGGhHUY654zQMwIAgGVCOox0rqYxhBEAACwT0mHE1hFGvAzTAABgmZAOI2Ed354JrAAAWCekwwiraQAAsF5AYWTp0qXKzc2Vy+WSy+VSfn6+3nzzzR7PX7ZsmWw2W5dHZGTkRRfdX86spiGMAABglYDu2puRkaEnn3xS48aNkzFGv/71r3Xbbbdp165dmjx5crfXuFwu7d+/3/+8c57GYNDZM+IliwAAYJmAwsitt97a5flPf/pTLV26VNu2besxjNhsNqWmpva9wgEU1pGLWE0DAIB1+jxnxOv1asWKFWpoaFB+fn6P59XX12v06NHKzMzUbbfdpr179/b1I/udv2eEYRoAACwTUM+IJO3evVv5+flqbm5WbGysVq5cqUmTJnV77vjx4/XSSy8pNzdXNTU1+vnPf67Zs2dr7969ysjI6PEzPB6PPB6P/3ltbW2gZfaKf84IWQQAAMsE3DMyfvx4FRYWavv27XrooYd0zz33aN++fd2em5+fr0WLFunyyy/X3Llz9dprr2nkyJH65S9/ecHPKCgokNvt9j8yMzMDLbNXwlhNAwCA5QIOIw6HQzk5OZo2bZoKCgqUl5enZ599tlfXRkREaOrUqSoqKrrgeUuWLFFNTY3/UVJSEmiZvWLv3GeErhEAACxz0fuM+Hy+LkMqF+L1erV7926lpaVd8Dyn0+lfPtz5GAhnVtMQRgAAsEpAc0aWLFmiBQsWKCsrS3V1dVq+fLk2btyodevWSZIWLVqkUaNGqaCgQJL04x//WLNmzVJOTo6qq6v11FNP6ejRo7r//vv7/5v0wZlNzywuBACAEBZQGKmsrNSiRYtUVlYmt9ut3NxcrVu3TjfccIMkqbi4WHb7mc6W06dP64EHHlB5eblGjBihadOm6YMPPuhxwmuwhbHpGQAAlgsojPzqV7+64PGNGzd2ef7MM8/omWeeCbioYOncf40JrAAAWCek700TxpwRAAAsF9JhpHOfEbIIAADWCe0wwg6sAABYLsTDSPt/mTMCAIB1QjqMsJoGAADrhXQYsbHPCAAAlgvpMMJqGgAArBfSYaRzzoghjAAAYJnQDiN2VtMAAGC10A4jzBkBAMByIR1Gwjq+PatpAACwTkiHkTM9I4QRAACsQhiR5CWLAABgmRAPI+3/pWcEAADr
hHQYYQdWAACsF9JhxMacEQAALBfSYSTMv8+IxYUAABDCQjqMsAMrAADWC/Ewwr1pAACwGmFE7MAKAICVQjqMsJoGAADrhXQYsbHPCAAAlgvpMBLGXXsBALBcSIeRzjkjdIwAAGAdwohYTQMAgJVCPIy0/5c5IwAAWCekwwiraQAAsF5IhxH2GQEAwHqhHUZYTQMAgOVCO4wwZwQAAMuFdBgJ8w/TEEYAALBKSIcRG3NGAACwXEiHEVbTAABgvZAOI8wZAQDAeqEdRuzswAoAgNVCO4x0zhnxWVwIAAAhLKTDCKtpAACwXkiHERtzRgAAsFxIh5Ew/w6sFhcCAEAIC+kw0jlnxNAzAgCAZUI6jIR1fHtW0wAAYJ2QDiP+HVjZ9AwAAMuEdBgJYzt4AAAsF9JhxM7SXgAALBfaYaRzzghdIwAAWCa0w4h/NY3FhQAAEMJCOoyEcW8aAAAsF9JhhB1YAQCwXkiHkbCzhmnY+AwAAGuEdBjpnDMisbwXAACrhHYYsZ8JI6yoAQDAGqEdRs5kEeaNAABgkZAOI2H2s4dpCCMAAFghpMPI2XNGGKYBAMAaAYWRpUuXKjc3Vy6XSy6XS/n5+XrzzTcveM2rr76qCRMmKDIyUpdddpnWrFlzUQX3J0fYma/f0uazsBIAAEJXQGEkIyNDTz75pHbu3KkdO3bouuuu02233aa9e/d2e/4HH3ygO++8U/fdd5927dqlhQsXauHChdqzZ0+/FH+x7HabHOHtTdBMGAEAwBI2c5EbbCQkJOipp57Sfffdd96xO+64Qw0NDVq9erX/tVmzZunyyy/X888/3+vPqK2tldvtVk1NjVwu18WUe57cH61TbXObNjw+V5eMjO3X9wYAIJT19vd3n+eMeL1erVixQg0NDcrPz+/2nK1bt2revHldXps/f762bt3a14/td1GOMElSU4vX4koAAAhN4YFesHv3buXn56u5uVmxsbFauXKlJk2a1O255eXlSklJ6fJaSkqKysvLL/gZHo9HHo/H/7y2tjbQMnstMqI9jHjaCCMAAFgh4J6R8ePHq7CwUNu3b9dDDz2ke+65R/v27evXogoKCuR2u/2PzMzMfn3/s0VFdPaMMGcEAAArBBxGHA6HcnJyNG3aNBUUFCgvL0/PPvtst+empqaqoqKiy2sVFRVKTU294GcsWbJENTU1/kdJSUmgZfaasyOMNLfSMwIAgBUuep8Rn8/XZUjlbPn5+dqwYUOX19avX9/jHJNOTqfTv3y48zFQoiI6V9MQRgAAsEJAc0aWLFmiBQsWKCsrS3V1dVq+fLk2btyodevWSZIWLVqkUaNGqaCgQJL0yCOPaO7cuXr66ad1yy23aMWKFdqxY4deeOGF/v8mfRQZwQRWAACsFFAYqays1KJFi1RWVia3263c3FytW7dON9xwgySpuLhYdvuZzpbZs2dr+fLl+td//Vc98cQTGjdunFatWqUpU6b077e4CJHhHcM07DMCAIAlAgojv/rVry54fOPGjee9dvvtt+v2228PqKhg6lza20zPCAAAlgjpe9NIUmTnnBEmsAIAYAnCSOdqGiawAgBgCcII+4wAAGCpkA8jUfSMAABgqZAPI/45I0xgBQDAEoQRekYAALAUYYRNzwAAsBRhxH9vGiawAgBghZAPI/679rLPCAAAlgj5MBLTsQNrY0ubxZUAABCaQj6MxEVGSJLqmgkjAABYgTAS2X57HsIIAADWIIx0hJF6T5u8PmNxNQAAhB7CSMcwjSTV0zsCAEDQhXwYcYTb/buw1ja3WlwNAAChJ+TDiHSmd4QwAgBA8BFGJLmYxAoAgGUII2J5LwAAViKM6OzlvQzTAAAQbIQRSa7OOSNNhBEAAIKNMCLJFcWcEQAArEIY0VlzRjyEEQAAgo0wIim642Z5DYQRAACCjjAiKcbRPkzT2OK1uBIAAEIPYURStJOeEQAArEIYET0jAABYiTCis+aMtNAzAgBAsBFGJMU4O3pG
PPSMAAAQbIQR0TMCAICVCCM60zPSxJwRAACCjjAiekYAALASYURnVtM0t/rk9RmLqwEAILQQRnRmnxFJaqR3BACAoCKMSHKE2RVut0lirxEAAIKNMCLJZrMpivvTAABgCcJIB3ZhBQDAGoSRDtyfBgAAaxBGOtAzAgCANQgjHdhrBAAAaxBGOnB/GgAArEEY6UDPCAAA1iCMdGDOCAAA1iCMdGA1DQAA1iCMdKBnBAAAaxBGOtAzAgCANQgjHegZAQDAGoSRDlGspgEAwBKEkQ7+nhH2GQEAIKgIIx38c0boGQEAIKgIIx2YMwIAgDUIIx06d2BtpGcEAICgIox04N40AABYgzDSIeas1TTGGIurAQAgdBBGOkR39Iz4jNTc6rO4GgAAQkdAYaSgoEDTp09XXFyckpOTtXDhQu3fv/+C1yxbtkw2m63LIzIy8qKKHgjREWGy29r/XNfcam0xAACEkIDCyKZNm7R48WJt27ZN69evV2trq2688UY1NDRc8DqXy6WysjL/4+jRoxdV9ECw222K7egdqSWMAAAQNOGBnLx27douz5ctW6bk5GTt3LlTV199dY/X2Ww2paam9q3CIHJFRai2uU21zayoAQAgWC5qzkhNTY0kKSEh4YLn1dfXa/To0crMzNRtt92mvXv3XszHDpi4yAhJUh1hBACAoOlzGPH5fHr00Uc1Z84cTZkypcfzxo8fr5deekmvv/66XnnlFfl8Ps2ePVvHjh3r8RqPx6Pa2touj2BwRXYM0zQxTAMAQLAENExztsWLF2vPnj3asmXLBc/Lz89Xfn6+//ns2bM1ceJE/fKXv9RPfvKTbq8pKCjQv/3bv/W1tD6jZwQAgODrU8/Iww8/rNWrV+vdd99VRkZGQNdGRERo6tSpKioq6vGcJUuWqKamxv8oKSnpS5kB8/eMMIEVAICgCahnxBij73znO1q5cqU2btyo7OzsgD/Q6/Vq9+7duvnmm3s8x+l0yul0BvzeF8sV1dkzQhgBACBYAgojixcv1vLly/X6668rLi5O5eXlkiS3262oqChJ0qJFizRq1CgVFBRIkn784x9r1qxZysnJUXV1tZ566ikdPXpU999/fz9/lYsX19EzwjANAADBE1AYWbp0qSTpmmuu6fL6yy+/rHvvvVeSVFxcLLv9zOjP6dOn9cADD6i8vFwjRozQtGnT9MEHH2jSpEkXV/kAcHXMGWECKwAAwRPwMM2X2bhxY5fnzzzzjJ555pmAirIKPSMAAAQf96Y5S+dqGiawAgAQPISRs7ii6BkBACDYCCNniWPOCAAAQUcYOYuLOSMAAAQdYeQsnT0j9S1t8vm+fLIuAAC4eISRs3SupjFGqvPQOwIAQDAQRs4SGREmR3h7k7ALKwAAwUEYOceZO/fSMwIAQDAQRs7hiuT+NAAABBNh5BzswgoAQHARRs7ReededmEFACA4CCPnoGcEAIDgIoycgzv3AgAQXISRc7g7hmmqCSMAAAQFYeQcSbFOSdLJeo/FlQAAEBoII+dIinNIIowAABAshJFzdPaMnKgjjAAAEAyEkXOMjOscpmmxuBIAAEIDYeQcnT0jpxtb1Or1WVwNAADDH2HkHCOiHbLb2u/cW9VA7wgAAAONMHKOMLtNicwbAQAgaAgj3WB5LwAAwUMY6UZSbPvyXnpGAAAYeISRbrCiBgCA4CGMdGMkwzQAAAQNYaQbbHwGAEDwEEa6cWaYhjACAMBAI4x0g54RAACChzDSDW6WBwBA8BBGujHSvyV8K1vCAwAwwAgj3RgR7VCY3SZJOsXyXgAABhRhpBt2u00JMQzVAAAQDISRHnQO1ZwgjAAAMKAIIz1IimNFDQAAwUAY6UHn/WkYpgEAYGARRnrg3/isjgmsAAAMJMJID5gzAgBAcBBGenBmF9ZmiysBAGB4I4z0INUdKUkqqyGMAAAwkAgjPRgVHyVJKqtuls9nLK4GAIDhizDSg1R3pOw2qcXrY94IAAADiDDS
g4gwu1Jd7UM1pdVNFlcDAMDwRRi5gFEj2odqSk8TRgAAGCiEkQvonDdCzwgAAAOHMHIB9IwAADDwCCMXMCo+WhI9IwAADCTCyAWkx3dMYKVnBACAAUMYuYCMEWfmjBjDXiMAAAwEwsgFpHdMYK33tKm2qc3iagAAGJ4IIxcQ7QhXQoxDknSsutHiagAAGJ4II18is2OopqSKeSMAAAwEwsiXyEqMkSQVVzVYXAkAAMMTYeRLjElsX9575BTDNAAADATCyJcY3dEzcvQUPSMAAAwEwsiX8PeMnKRnBACAgRBQGCkoKND06dMVFxen5ORkLVy4UPv37//S61599VVNmDBBkZGRuuyyy7RmzZo+FxxsnT0jx2ua5GnzWlwNAADDT0BhZNOmTVq8eLG2bdum9evXq7W1VTfeeKMaGnoewvjggw9055136r777tOuXbu0cOFCLVy4UHv27Lno4oMhKdahGEeYjJGOsRMrAAD9zmYuYmvREydOKDk5WZs2bdLVV1/d7Tl33HGHGhoatHr1av9rs2bN0uWXX67nn3++V59TW1srt9utmpoauVyuvpbbZwuefU+fldXqpXuv1HUTUoL++QAADEW9/f19UXNGampqJEkJCQk9nrN161bNmzevy2vz58/X1q1be7zG4/Gotra2y8NKzBsBAGDg9DmM+Hw+Pfroo5ozZ46mTJnS43nl5eVKSenam5CSkqLy8vIerykoKJDb7fY/MjMz+1pmv2BFDQAAA6fPYWTx4sXas2ePVqxY0Z/1SJKWLFmimpoa/6OkpKTfPyMQnT0jh9lrBACAfhfel4sefvhhrV69Wps3b1ZGRsYFz01NTVVFRUWX1yoqKpSamtrjNU6nU06nsy+lDYic5FhJ0hfldRZXAgDA8BNQz4gxRg8//LBWrlypd955R9nZ2V96TX5+vjZs2NDltfXr1ys/Pz+wSi00Ia190k15bbNON7RYXA0AAMNLQGFk8eLFeuWVV7R8+XLFxcWpvLxc5eXlamo6s+R10aJFWrJkif/5I488orVr1+rpp5/W559/rh/96EfasWOHHn744f77FgMs1hmurIT2oZrPyq2dTAsAwHATUBhZunSpampqdM011ygtLc3/+P3vf+8/p7i4WGVlZf7ns2fP1vLly/XCCy8oLy9Pf/zjH7Vq1aoLTnodjCakxkmSPitjqAYAgP4U0JyR3mxJsnHjxvNeu/3223X77bcH8lGDzoTUOL21r0JFlYQRAAD6E/em6aWclPaekQMV9RZXAgDA8EIY6aVxHStqDlTW96qHCAAA9A5hpJeyk2Jkt0k1Ta06Ue+xuhwAAIYNwkgvRUaE+XdiLWKoBgCAfkMYCUDOWUM1AACgfxBGAnBm3ggragAA6C+EkQCMS+kIIwzTAADQbwgjARiX3LG8lxU1AAD0G8JIAHKSY+UIt6uqoUUHTzRYXQ4AAMMCYSQAkRFhunL0CEnSBwdPWlwNAADDA2EkQHNykiRJWw+esrgSAACGB8JIgKZ19Ix8eqzG4koAABgeCCMBmpTukiSVVjepurHF4moAABj6CCMBckVGKCshWpK073itxdUAADD0EUb6YHJH78gnDNUAAHDRCCN9MCM7QZL0fhEragAAuFiEkT6Ye+lISdKHh6vU2NJmcTUAAAxthJE+yE6KUcaIKLV4fdp+uMrqcgAAGNIII31gs9k0+5JESdK2Q+w3AgDAxSCM9FF+Zxhh8zMAAC4KYaSPZo1tDyO7S2tU29xqcTUAAAxdhJE+SnNHaUxitHxG+oh5IwAA9Blh5CJ0DtVwnxoAAPqOMHIRZl/SftO8d/ZXyhhjcTUAAAxNhJGLcM34kXKE2XXoRIMOVNZbXQ4AAEMSYeQixEVG6Cvj2ntH/vTxMYurAQBgaCKMXKQ7Z2RJkn63vVgNHnZjBQAgUISRi3TdhGSNSYxWbXMbvSMAAPQBYeQi2e02fXNOtiTp1x8csbYYAACGIMJIP/ibK0YpzG7TwRMN
Kq1usrocAACGFMJIP4iLjNDkdJckNkADACBQhJF+MmNMgiQ2QAMAIFCEkX5y3YRkSdLrn5TqRJ3H4moAABg6CCP9JP+SROVlxqu51affbj9qdTkAAAwZhJF+YrPZdO/s0ZLaN0Dz+dgeHgCA3iCM9KObJqcp1hmukqom/X5HidXlAAAwJBBG+lGUI0z/cH2OJOn/W/u5Wtp8FlcEAMDgRxjpZ/ddNVZJsU5VN7bqg4MnrS4HAIBBjzDSz8LsNt00JUWStGZ3mcXVAAAw+BFGBsDNl6VJkt7aV6FWL0M1AABcCGFkAMzMTlRijEPVja3aUsRQDQAAF0IYGQBhdptuzUuXJP3qvcMWVwMAwOBGGBkg912VrXC7TVuKTmrj/kqrywEAYNAijAyQzIRo3TN7jCTpR2/sZRM0AAB6QBgZQI/dcKninOE6cqpR2w5zAz0AALpDGBlAMc5wfTWvfWXNS1uOWFsMAACDFGFkgH1rTrbC7Da9/VkFm6ABANANwsgAG5cSp29Mz5QkfX/VHtU1t1pcEQAAgwthJAjumtl+N9+DJxr09V9uYzIrAABnIYwEwaR0l3/fkc/KarX8w2KLKwIAYPAgjATJL+6cqvuvypYk/euqPXrn8wqLKwIAYHAgjATRYzdeqlljEyRJ//n2AYZrAAAQYSSooh3h+r9/d4UiI+z69FiNfrbmM6tLAgDAcoSRIEuKderJv8mVJP3q/cPad7zW4ooAALBWwGFk8+bNuvXWW5Weni6bzaZVq1Zd8PyNGzfKZrOd9ygvL+9rzUPewqmj9NXcNBkjLd100OpyAACwVMBhpKGhQXl5eXruuecCum7//v0qKyvzP5KTkwP96GHlga+MlST9+ZPjeu3jYxZXAwCAdcIDvWDBggVasGBBwB+UnJys+Pj4gK8brnIz3BqdGK2jpxr12B8+UenpJv393LFyhodZXRoAAEEVtDkjl19+udLS0nTDDTfo/fffv+C5Ho9HtbW1XR7Djc1m009um+J//vT6L/TEa3ssrAgAAGsMeBhJS0vT888/rz/96U/605/+pMzMTF1zzTX6+OOPe7ymoKBAbrfb/8jMzBzoMi1x9aUj9fZjc/3PX9t1TLuKT1tYEQAAwWczxvR5swubzaaVK1dq4cKFAV03d+5cZWVl6X//93+7Pe7xeOTxePzPa2trlZmZqZqaGrlcrr6WO2g1eNr0+B8+0dq95UqIcWjzd69VrDPgETQAAAaV2tpaud3uL/39bcnS3hkzZqioqKjH406nUy6Xq8tjOItxhuvnX89TVkK0qhpaVLDmMz33bpGaW71WlwYAwICzJIwUFhYqLS3Nio8etGKd4bqj4+6+v91erKfW7dcz67+wuCoAAAZewGMB9fX1XXo1Dh8+rMLCQiUkJCgrK0tLlixRaWmpfvOb30iS/vM//1PZ2dmaPHmympub9eKLL+qdd97RW2+91X/fYpj4+pWZemXbUZXVNEuSfrn5kKZmjdD0MSOUGOu0uDoAAAZGwGFkx44duvbaa/3PH3vsMUnSPffco2XLlqmsrEzFxWfuStvS0qLHH39cpaWlio6OVm5urt5+++0u74F2I+OcWvvI1dp04IT+a8MBFVXW68FXdsoRZtefv3OVxqfGWV0iAAD97qImsAZLbyfADCcvbD6on6353P/829dcou/eNMHCigAACMygnsCKL/f1KzN1VU6SJqe3/89bu7dcnjYmtAIAhh/CyCAVH+3QK/fP1PIHZikywq5DJxr09V9uU1FlnVrafFaXBwBAvyGMDHLuqAg9f/c0uaMi9ElJteb9x2bd+Mwm1TW3Wl0aAAD9gjAyBFwzPlmrv3OVpmbFS5KOnGrUt3/7saoaWqwtDACAfkAYGSIyE6K18ttz9Pzd02S3Se8dOKl7X/5QZTVNVpcGAMBFIYwMMTdNSdXvHpilqIgwfXqsRrf+YotqmhiyAQAMXYSRIWjm2ES9/M3pioyw62R9i/L+7S1d9/ON+vm6
/cwlAQAMOewzMoSt2V2mb//2/Lsf3zQ5VXdMz9S1E5ItqAoAgHa9/f1NGBniPjpSpbrmVpWebtL3X9/b5diWf75WGSOiLaoMABDqCCMhqLKuWTN+usH/PD46Qo/fcKnunjVaNpvNwsoAAKGIHVhDUHJcpH7w1UlKcbXfVK+6sVXff32vXtl2VO/ur5TPN+hzJwAgBNEzMkyt31ehB36zo8tr984eoxsmpSjNHamxI2MtqgwAECoYpoGOVzfpxmc2q97T1uV1u0169htTdWteukWVAQBCAWEEkqTK2mbJJi3deFC/2XpU3o6hGptNuubSkUpxReqfb5qgETEOiysFAAw3hBGcx9PmlTHSgmff0+GTDV2O/d3MLM0am6hbLktTmJ3JrgCAi0cYQY+qG1v01t4KfXKsWr/dXtzl2PQxIzQjO0FfvzJToxNjLKoQADAcEEbQKwVvfqZXdxw776Z7CTEOXT8hWdeMT9YtuWkWVQcAGMoIIwiIMUY//ctnevuzCh051eh/3RFm1+bvXit3VISiHGEWVggAGGoII+izTV+c0I//vFcHT3SdV5IU69D1E1K05OYJinGGq6nVK1dkhEVVAgAGO8IILtpHR6r0rZc/Ut05S4MTYxyy222qaWrVH/5Pvi7PjLemQADAoEYYQb9obvXqRJ1H//edIv1+R0m359w0OVW5mW59a062IsLsrMYBAEgijKCfedq8WrunXK7ICG07dErJrkj9ZPW+LueE220Ks9t075wxevDqS9i7BABCHGEEA271p8e1+pMyrd1b3u3xrIRojR0ZowZPm+ZPTtWCy9I0Kj4qyFUCAKxCGEFQvV5Yqr3Ha9XS5tO7+yt19KwVOWdbtXiOLs+Ml9dntPXgKeVluhUZESa7zcbwDgAMM4QRWOoff1+olbtKuz12RVa84qMdeufzSkU7wtTmNZqU7tKrD+YrIowbSQPAcEEYgaWaW73aevCUrhqXpIgwu97cXaaHfvvxBa/5p/njVdPUqrf3Veie2WN018ws2W022ekxAYAhiTCCQcUYo//eeFAvbD6kmqZW3TQ5VZdluLVqV6kOVNb3eF3+2ET99v6ZBBIAGIIIIxi0mlu9iow4s5urMUbff32PXtlW3O35N1+WqjhnhD6vqJPPZzQ1K143TErRrLGJDOsAwCBGGMGQ4vMZ/XTNZzp6qlE/vz1X33hhmz4vr/vS656/e5pKq5s0MztBxVWN2nn0tO6ckamc5LggVA0AuBDCCIa05lavPG0+rdtbrs1fnFB2UowcYXY9vf6LXl2flxmv/LGJum5Csi4b5VaUI0x1za2y2WyKcYTJZmPYBwAGGmEEw9IbnxzXwcp61Xva9Ksth7scc0dFqKap9bxrYp3huionqct+KF/NTdODcy/RlFHuAa8ZAEIVYQTD3nsHTqisullfuTRJn5fX6epxI2WM0aelNTp6qkFr95SrsKRaFbWebq8Pt9v0tSsyVN3UouKqJl2aEqurcpI0OyfJvzmb12fY/wQA+ogwAqh9Lsp7RSe17P3Denf/iV5fFxURpjFJMTp0ol53TM/Ut6/JUWSEXdsOVenIqQbdcWWmDp9q0NTMeIZ8AKAHhBGgG0WVdUpzRykyIkzPbjigP+4oUU5KnGqbWlVe06zy2uaA3i83w61/uXmiYpzhqm5s1VXjkgaocgAYeggjQB+crPfoFxsOaESMQz4jfVB0Up+W1qilzder62+clKIjpxpUXtOsqy8dqYlpLt09a7T2Ha/V2JExGhHtUJidre8BhAbCCNBPGlva1Oo1qmpo0al6jzITonWwsl4/+ctn+qysNuD3i4ywa97EFM3JSdKteek6dKJeSbFOlVY3aXK6S9GO8AH4FgAQfIQRYIC1tPlks0mHTjRof0WddhWf1mdltdp+uErGtIeO5tb2HhW7TfL14l+aOypCC6akqrHFq9LqJj3z9cuVlRjtP15S1aiT9R4lxTr150+P6+5Zo+WKjBiorwgAF4Uw
AliszevT3z6/VacbW/Tq/8nXkVONWrz8Y52oa1/dY7NJX/avL9xu08g4p1q9Rq1en2qbW7tck5cZr6mZ8Sosqdbds0ZrQmqcJqa5ej0MVFnbrMKSat0wKYWJuAD6HWEEGCR8PuO/t87phhb9ZutRfWNGpo6cbNDyD4t1/1VjdWlqrMLtdq3cVaqCNZ9pZJyzVzvQdic7KUbXTUhWjDNcm/ZX6lRDi67KSVJuRrz2ldVoze5yPTh3rDZ9cULvF52SJP3NFaP0vZsmyGuM0txR/fbdAYQ2wggwRBljZLPZ5PMZFZ2oV0ubT02tXr39WYWmZY3QzqOn1djiVWZClFbuOq6IMJs+PVajMLtNEWE2/9BQX900OVXfvvYS/c97h9XgadPkdJd2l9boxkmpumN6ZpdeF0+bV44wO70qALpFGAFCSKvXpzavUZvPpzf3lOv1wlL5fNKoEVGanO5SSVWTik7UKzHGoZW7Si/qsyLCbLLJphZve+i5ZGSMpmaNUKvXp9unZeqTY9VyhttV3diqMUkx8hmj+KgIzc5JUqyzfXLugYo6najzaHYOS6GB4YwwAqBbWw+e0msfH9P/kz9aMc5wZSVE69Udx/TEyt36am6aWtp8emtfxXnXRUWEqanV2+fPHRUfpe9cl6Oiynq92LGV/9iR7fcccobbNScnSfddla3GFq82fnFC8yenKDkuUq1enyrrPCo+1aj/2nBAj914qaaPSehzHQCChzACICBVDS1KiHHIGKMDlfVqbvUqKdap331YrLtnjZY7KkJHTjXoidd2a395nW7JTdNf5Y3SO59Xat3ecpVWN3V5vxnZCWpsaVNCjFNflNcFvKFcYoxD08ckdLmnUKdbctNkt9l07+zR8rT5NDYpViNiIjpWMbXfFLGuuU27iqt1aUqsSqub9Obuco1JitE3pmcqLjJcnjafYpyBLaPuHEID0DuEEQBBdehEvcLsNiXGOlVV39JlSXK9p03/vvZzbSk6qUMnGiRJfzN1lJpavbLZpC8q6lVUWX9Rn9+5fHpimkuHT9ZfcO5MjCNMDS1epbsjNW1MgmZkJ8gdFaFZYxNU3diqNq/RJ8eqNToxWjkjY+U1Ri++d1ivfXxMDR6vFlyWqidunqgUV6Sk9rtMO8Pb584cO92oPaU1mj85leCCkEcYATAoeX1Gx6ublJkQ3eX15o4hIJ8xamrx6j/fPqD/3XZUkjQhNU7/cP04VdY2a8VHJcrLiFd1U4vW7T1/OOlCOkNIf+p8z8yEKM0Yk6g/fXxMkjQuOVb//re5ujwzXsZIrT6fTje06t39lUpxOTX30mT/ZGCfz6i5zevf8O50Q4vioyNks9lUUduspFgnu/ZiSCKMABjyjDE6eKJe2Umx5/0y9vqMnnu3SG0+o+smJCs5zimp/W7OkRFhSnVF6u3PKrRy13EZY/TQNZfoq7npemrdfo0dGaNUV/t8lG2HTunQyQZ5fUZ7j5/ZUdcdFaH46AgdO90kr88oKdahk/UtAX+HWGe46j1t570+b2KKrpuQrM1fnNDWQ6dU09R63jlxznDVedqUMSJKl41yKzMhWl6f0czsBL38/hElxjpUXtOsqVnx+qf5E+QIt+t0Q4u2FJ2UOypCY0fGKN0dpaIT9Xpq3X49OHespo1mvg2ChzACAAE6Ve/R6cZWvfHJcX01N02XpsSpqqFFh07Ua9roEfL6jE43tiouMly//6hEa/eU66t5acofm6hDJxq04fNKrdld1m2w6DQyzqnTDS1q682WvAGYkZ2gFFek1u4pU6v3zHt/c84YbTtU5b91wcQ0ly7PjNfWgyd1/cQUvV90UtdPTNbENJemj0nQyXqPRkQ7VFrdpEMn6hUf7dCkNJc8bV41eLwanxqnv3xapmSXU18ZN1LlNc0KD7OpstajSek9/3w2xqii1qPmVq/GJMUE9N0aW9pU1dCijBHRX34yBhXCCABYxOsz8vqM3vjkuEqqGnW8uklVDS2Kj3bo
X26ZqM/La/XoikL5jHT9hGSluiOVl+nW7z4s0f7yOo2IcajR06YDlfUaFR+lSekure9mhdOssQn+2w+czR0VccFA1F/+Ki9db3xy3P88Y0SU/nrqKF0zfqSOVzfr+6/v0fiUOGUmROuPO4/5z8tJjlVijEM/WThFn5fXaX95rSakutTY0qaIMLvyMuP1p53HlJcZrytHj9B3frdLHx6u0tNfz9PRU4366EiVfnHnVEU5wlTX3H6NO+r82yJUNbSopKpRuRlu5u9YhDACAINYq9enMJvNvztvd0qqGhXlCFNSrFPNrV6t+LBY8yalyOszSohxKK7jvkRFlXVavr1EX1TUKf+SRN07e4yeWrdfyz44Iqm9N+S7N41Xc4tX3399z3nDTeF2m1LdkTp2uuncEvyiIsLU5vN16XUZTK7IitePb5uiKEeYRsVHaduhU3rwlZ1qbvXphkkpstukRfljNGtsorYePKV6T3tYO3a6SSOiHXJG2PXh4SpNSXfrb64Ypbf2VWjf8Vo9eM0l/v1xpPb5PcVVjdp7vFY2m3TdhGRFRoTJGCNj1OX/Z3lNsxJjHYoIs3eZ5BxKCCMAEOI8bV6V1zQrKyHa/0uwpc2no6calJMcq6ZWryLDw+QzRnabTRV1zXJHRei1j0s1MS1Ob+2rkDM8TPljE3XF6Hht/uKkfvHOAZXXNKuy4x5LSbHOjl+y0viUOIWH2bRx/wl52nzKy3ArLjJCZTVNum5CssalxOnjo6fV1OrV64XtPSopLqeuyhmpohP18vmMdpfWWNZe3UmMcejKMSO0/XCVshKiVdXQcl5omzLKpT2ltYp2hGn2JYlq8HjlM0bbD1dpbFKMrr50pJZvL9b41DhdPzFZHx2pkqfVp8sy3Lrt8lF6c0+Z7skfozR3pKoaWpQY69TvPyrWsg+O6u9mZMrT5tNNU1J7PUxV19yqkqom/7DZqXqP/t9XP9Etuen622kZqm1uVZwzPCjBiDACABgwR081KNUdKWd42HnHfD6jVp+v22Od6j1tOlnnUcaIKIWH2f2vVzW06PDJBuVmuLXt0CkteulDjU+J0z/ecKk8bT6dbmjRtkOn5I6K0P1fydbu0ho98dqebjfkc4bb5Wnrfol3XoZbklTnadOp+hbVNLUq3G7r9VyeUfFRKq9tlref5v6cvangiOgInW48f5gt1RWpMUnROlnfvidQXXObZl+SqOKqRn189LQc4XaNS4nTntIaVTW06IZJKWpq8WpL0Un/e/zktsn68ep9avUaLbw8XfHRDn16rP1Gm389dVS/BxTCCABgyDtysj30REb0HGx8PqM2n9Hu0molx0WqpqlV7qgIJcQ49OSbn2vP8Rr96y0T9fZnlRoVH6WZ2QkalxLX5T2aW9vvs7Sz+LQOnajX3EuTtbu0RruKT+vj4tMqqqzXvIkpGp8ap1ty05QcF6kDFXXafrhK1Y0t2vB5pWZfkqiEGKcaPG1auvGgIsJsmjk2UXtKa1RWE9imf52SYp06We/p07WB+vev5err0zP79T0JIwAAWKTB0yYjdZlvcrqhRfWeNkU7wpQY69QHRSdVWt2kGyel6vPyWq3dW65rxiertc2nLUUn9Xczs3RpSpyMMdpXVqvffVgsV2SEpmcn6FR9iw5U1PlvrfDj2yZr/b4KNbV49dXcNOUkx2nN7jKdbmzR6k/LutSWkxyr8SlxcobbVdXYorLqZtls0qrFcy4Y+vqCMAIAwDB3oKJOLV6fJqe7ezync7J04bFqHaio09euyOgyNCZJbV7fea/1h97+/g74kzdv3qxbb71V6enpstlsWrVq1Zdes3HjRl1xxRVyOp3KycnRsmXLAv1YAABwjnEpcRcMIpIUEWaX3W7TFVkjdMf0rG5Dx0AEkUAE/OkNDQ3Ky8vTc88916vzDx8+rFtuuUXXXnutCgsL9eijj+r+++/XunXrAi4WAAAMP4HdslLSggULtGDBgl6f//zzzys7O1tPP/20
JGnixInasmWLnnnmGc2fPz/QjwcAAMPMgPfLbN26VfPmzevy2vz587V169Yer/F4PKqtre3yAAAAw9OAh5Hy8nKlpKR0eS0lJUW1tbVqaup+t7+CggK53W7/IzOzf5caAQCAwcPaGSs9WLJkiWpqavyPkpISq0sCAAADJOA5I4FKTU1VRUXXGzxVVFTI5XIpKiqq22ucTqecTudAlwYAAAaBAe8Zyc/P14YNG7q8tn79euXn5w/0RwMAgCEg4DBSX1+vwsJCFRYWSmpfultYWKji4mJJ7UMsixYt8p//4IMP6tChQ/rud7+rzz//XP/93/+tP/zhD/rHf/zH/vkGAABgSAs4jOzYsUNTp07V1KlTJUmPPfaYpk6dqh/84AeSpLKyMn8wkaTs7Gz95S9/0fr165WXl6enn35aL774Ist6AQCAJLaDBwAAA2TAtoMHAADoT4QRAABgKcIIAACw1IDvM9IfOqe1sC08AABDR+fv7S+bnjokwkhdXZ0ksS08AABDUF1dndxud4/Hh8RqGp/Pp+PHjysuLk42m63f3re2tlaZmZkqKSlhlc6XoK16h3bqPdqqd2in3qOteieY7WSMUV1dndLT02W39zwzZEj0jNjtdmVkZAzY+7tcLv7i9hJt1Tu0U+/RVr1DO/UebdU7wWqnC/WIdGICKwAAsBRhBAAAWCqkw4jT6dQPf/hD7hDcC7RV79BOvUdb9Q7t1Hu0Ve8MxnYaEhNYAQDA8BXSPSMAAMB6hBEAAGApwggAALAUYQQAAFgqpMPIc889pzFjxigyMlIzZ87Uhx9+aHVJQbV582bdeuutSk9Pl81m06pVq7ocN8boBz/4gdLS0hQVFaV58+bpwIEDXc6pqqrSXXfdJZfLpfj4eN13332qr68P4rcYeAUFBZo+fbri4uKUnJyshQsXav/+/V3OaW5u1uLFi5WYmKjY2Fh97WtfU0VFRZdziouLdcsttyg6OlrJycn6p3/6J7W1tQXzqwy4pUuXKjc317+ZUn5+vt58803/cdqpe08++aRsNpseffRR/2u0Vbsf/ehHstlsXR4TJkzwH6edzigtLdXdd9+txMRERUVF6bLLLtOOHTv8xwf1z3QTolasWGEcDod56aWXzN69e80DDzxg4uPjTUVFhdWlBc2aNWvMv/zLv5jXXnvNSDIrV67scvzJJ580brfbrFq1ynzyySfmr/7qr0x2drZpamryn3PTTTeZvLw8s23bNvPee++ZnJwcc+eddwb5mwys+fPnm5dfftns2bPHFBYWmptvvtlkZWWZ+vp6/zkPPvigyczMNBs2bDA7duwws2bNMrNnz/Yfb2trM1OmTDHz5s0zu3btMmvWrDFJSUlmyZIlVnylAfPGG2+Yv/zlL+aLL74w+/fvN0888YSJiIgwe/bsMcbQTt358MMPzZgxY0xubq555JFH/K/TVu1++MMfmsmTJ5uysjL/48SJE/7jtFO7qqoqM3r0aHPvvfea7du3m0OHDpl169aZoqIi/zmD+Wd6yIaRGTNmmMWLF/ufe71ek56ebgoKCiysyjrnhhGfz2dSU1PNU0895X+turraOJ1O87vf/c4YY8y+ffuMJPPRRx/5z3nzzTeNzWYzpaWlQas92CorK40ks2nTJmNMe7tERESYV1991X/OZ599ZiSZrVu3GmPag5/dbjfl5eX+c5YuXWpcLpfxeDzB/QJBNmLECPPiiy/STt2oq6sz48aNM+vXrzdz5871hxHa6owf/vCHJi8vr9tjtNMZ//zP/2yuuuqqHo8P9p/pITlM09LSop07d2revHn+1+x2u+bNm6etW7daWNngcfjwYZWXl3dpI7fbrZkzZ/rbaOvWrYqPj9eVV17pP2fevHmy2+3avn170GsOlpqaGklSQkKCJGnnzp1qbW3t0lYTJkxQVlZWl7a67LLLlJKS4j9n/vz5qq2t1d69e4NYffB4vV6tWLFCDQ0Nys/Pp526sXjxYt1yyy1d2kTi
79S5Dhw4oPT0dI0dO1Z33XWXiouLJdFOZ3vjjTd05ZVX6vbbb1dycrKmTp2q//mf//EfH+w/00MyjJw8eVJer7fLX05JSklJUXl5uUVVDS6d7XChNiovL1dycnKX4+Hh4UpISBi27ejz+fToo49qzpw5mjJliqT2dnA4HIqPj+9y7rlt1V1bdh4bTnbv3q3Y2Fg5nU49+OCDWrlypSZNmkQ7nWPFihX6+OOPVVBQcN4x2uqMmTNnatmyZVq7dq2WLl2qw4cP6ytf+Yrq6upop7McOnRIS5cu1bhx47Ru3To99NBD+od/+Af9+te/ljT4f6YPibv2AoPF4sWLtWfPHm3ZssXqUgat8ePHq7CwUDU1NfrjH/+oe+65R5s2bbK6rEGlpKREjzzyiNavX6/IyEiryxnUFixY4P9zbm6uZs6cqdGjR+sPf/iDoqKiLKxscPH5fLryyiv1s5/9TJI0depU7dmzR88//7zuuecei6v7ciHZM5KUlKSwsLDzZlxXVFQoNTXVoqoGl852uFAbpaamqrKyssvxtrY2VVVVDct2fPjhh7V69Wq9++67ysjI8L+empqqlpYWVVdXdzn/3Lbqri07jw0nDodDOTk5mjZtmgoKCpSXl6dnn32WdjrLzp07VVlZqSuuuELh4eEKDw/Xpk2b9F//9V8KDw9XSkoKbdWD+Ph4XXrppSoqKuLv1FnS0tI0adKkLq9NnDjRP6Q12H+mh2QYcTgcmjZtmjZs2OB/zefzacOGDcrPz7ewssEjOztbqampXdqotrZW27dv97dRfn6+qqurtXPnTv8577zzjnw+n2bOnBn0mgeKMUYPP/ywVq5cqXfeeUfZ2dldjk+bNk0RERFd2mr//v0qLi7u0la7d+/u8g99/fr1crlc5/0AGW58Pp88Hg/tdJbrr79eu3fvVmFhof9x5ZVX6q677vL/mbbqXn19vQ4ePKi0tDT+Tp1lzpw552058MUXX2j06NGShsDP9AGdHjuIrVixwjidTrNs2TKzb98+8/d///cmPj6+y4zr4a6urs7s2rXL7Nq1y0gy//Ef/2F27dpljh49aoxpXwYWHx9vXn/9dfPpp5+a2267rdtlYFOnTjXbt283W7ZsMePGjRt2S3sfeugh43a7zcaNG7ssL2xsbPSf8+CDD5qsrCzzzjvvmB07dpj8/HyTn5/vP965vPDGG280hYWFZu3atWbkyJHDbnnh9773PbNp0yZz+PBh8+mnn5rvfe97xmazmbfeessYQztdyNmraYyhrTo9/vjjZuPGjebw4cPm/fffN/PmzTNJSUmmsrLSGEM7dfrwww9NeHi4+elPf2oOHDhgfvvb35ro6Gjzyiuv+M8ZzD/TQzaMGGPML37xC5OVlWUcDoeZMWOG2bZtm9UlBdW7775rJJ33uOeee4wx7UvBvv/975uUlBTjdDrN9ddfb/bv39/lPU6dOmXuvPNOExsba1wul/nmN79p6urqLPg2A6e7NpJkXn75Zf85TU1N5tvf/rYZMWKEiY6ONn/9139tysrKurzPkSNHzIIFC0xUVJRJSkoyjz/+uGltbQ3ytxlY3/rWt8zo0aONw+EwI0eONNdff70/iBhDO13IuWGEtmp3xx13mLS0NONwOMyoUaPMHXfc0WXvDNrpjD//+c9mypQpxul0mgkTJpgXXnihy/HB/DPdZowxA9v3AgAA0LOQnDMCAAAGD8IIAACwFGEEAABYijACAAAsRRgBAACWIowAAABLEUYAAIClCCMAAMBShBEAAGApwggAALAUYQQAAFiKMAIAACz1/wPpAdaSB8W9CAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(torch.tensor(l).view(-1, 10).mean(1).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "353K8yRito8c",
    "outputId": "128ac4eb-84da-4f6f-b171-2d5960769630"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def partition(uvalue, sockColumn):\n",
      "            raise ValueError(\"No :class:`Row``, over a a transform when condition and Convert a SQL types at :func:`DataFrame.ifSchema.toAurecordType` to the specified reverse to the given formSt (b, len(goneuter)).map(self._java_matrix_wrapper._jdf.groupBy(lambda v: \n"
     ]
    }
   ],
   "source": [
    "# 使用模型来生成文本\n",
    "begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n",
    "print(''.join(tok.decode(generate_batch(model, begin_text))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "ojDIISnWto8c"
   },
   "outputs": [],
   "source": [
    "# 将层归一化放到在LSTM神经元里面\n",
    "class LSTMLayerNormCell(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        '''\n",
    "        长短期记忆网络的神经元（内含层归一化）\n",
    "        参数\n",
    "        ----\n",
    "        input_size ：int，输入数据的特征长度\n",
    "        hidden_size ：int，隐藏状态的特征长度\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        combined_size = self.input_size + self.hidden_size\n",
    "        # 将四个线性模块放在一起定义，使得代码更加简洁和高效\n",
    "        self.gates = nn.Linear(\n",
    "            combined_size, 4 * self.hidden_size, bias=False)\n",
    "        # 用于门的层归一化\n",
    "        self.ln_gates = nn.LayerNorm(4 * self.hidden_size)\n",
    "        # 用于细胞状态的层归一化\n",
    "        self.ln_c = nn.LayerNorm(self.hidden_size)\n",
    "\n",
    "    def forward(self, inputs, state=None):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        inputs ：torch.FloatTensor\n",
    "            输入数据，形状为(B, I)，其中B表示批量大小，I表示文字特征的长度（input_size）\n",
    "        state ：tuple(torch.FloatTensor, torch.FloatTensor)\n",
    "            (隐藏状态，细胞状态)，两个状态的形状都为(B, H)，其中H表示隐藏状态的长度（hidden_size）\n",
    "        返回\n",
    "        ----\n",
    "        hs ：torch.FloatTensor，隐藏状态，形状为(B, H)\n",
    "        cs ：torch.FloatTensor，细胞状态，形状为(B, H)\n",
    "        '''\n",
    "        B, _ = inputs.shape\n",
    "        if state is None:\n",
    "            state = self.init_state(B, inputs.device)\n",
    "        hs, cs = state\n",
    "        combined = torch.cat((inputs, hs), dim=1)  # (B, I + H)\n",
    "        # 将四个线性模块分开\n",
    "        i, f, c, o = self.ln_gates(self.gates(combined)).chunk(4, 1)\n",
    "        # 输入门\n",
    "        ingate = F.sigmoid(i)      # (B, H)\n",
    "        # 遗忘门\n",
    "        forgetgate = F.sigmoid(f)  # (B, H)\n",
    "        # 输出门\n",
    "        outgate = F.sigmoid(o)     # (B, H)\n",
    "        # 更新细胞状态\n",
    "        ncs = F.tanh(c)            # (B, H)\n",
    "        cs = self.ln_c((forgetgate * cs) + (ingate * ncs))  # (B, H)\n",
    "        # 更新隐藏状态\n",
    "        hs = outgate * F.tanh(cs)                           # (B, H)\n",
    "        return hs, cs\n",
    "\n",
    "    def init_state(self, B, device):\n",
    "        cs = torch.zeros((B, self.hidden_size), device=device)\n",
    "        hs = torch.zeros((B, self.hidden_size), device=device)\n",
    "        return hs, cs\n",
    "\n",
    "class LSTMLayerNorm(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        '''\n",
    "        单层的长短期记忆网络（支持批量计算且内含层归一化）\n",
    "        参数\n",
    "        ----\n",
    "        input_size ：int，输入数据的特征长度\n",
    "        hidden_size ：int，隐藏状态的特征长度\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.lstm = LSTMLayerNormCell(self.input_size, self.hidden_size)\n",
    "\n",
    "    def forward(self, inputs, state=None):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        inputs ：torch.FloatTensor\n",
    "            输入数据的集合，形状为(B, T, C)，其中B表示批量大小，T表示文本长度，C表示文字特征的长度（input_size）\n",
    "        state ：tuple(torch.FloatTensor, torch.FloatTensor)\n",
    "            (初始的隐藏状态，初始的细胞状态)，两个状态的形状都为(B, H)，其中H表示隐藏状态的长度（hidden_size）\n",
    "        返回\n",
    "        ----\n",
    "        hidden ：torch.FloatTensor，所有隐藏状态的集合，形状为(B, T, H)\n",
    "        '''\n",
    "        re = []\n",
    "        B, T, C = inputs.shape\n",
    "        inputs = inputs.transpose(0, 1)  # (T, B, C)\n",
    "        for i in range(T):\n",
    "            state = self.lstm(inputs[i], state)\n",
    "            # 只记录隐藏状态，state[0]的形状为(B, H)\n",
    "            re.append(state[0])\n",
    "        result_tensor = torch.stack(re, dim=0)  # (T, B, H)\n",
    "        return result_tensor.transpose(0, 1)    # (B, T, H)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "wJ2ypYx6to8d"
   },
   "outputs": [],
   "source": [
    "class CharLSTMLayerNorm(nn.Module):\n",
    "\n",
    "    def __init__(self, vs):\n",
    "        '''\n",
    "        三层的长短期记忆网络（内嵌层归一化）\n",
    "        参数\n",
    "        ----\n",
    "        vs ：int，字典大小\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.emb_size = 256\n",
    "        self.hidden_size = 128\n",
    "        self.embedding = nn.Embedding(vs, self.emb_size)\n",
    "        self.dp = nn.Dropout(0.4)\n",
    "        self.lstm1 = LSTMLayerNorm(self.emb_size, self.hidden_size)\n",
    "        self.lstm2 = LSTMLayerNorm(self.hidden_size, self.hidden_size)\n",
    "        self.lstm3 = LSTMLayerNorm(self.hidden_size, self.hidden_size)\n",
    "        self.h2o = nn.Linear(self.hidden_size, vs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，当前字母在字典中的位置，形状为(B, T)\n",
    "        返回\n",
    "        ----\n",
    "        output ：torch.FloatTensor，预测结果的logits，形状为(B, T, vs)\n",
    "        '''\n",
    "        emb = self.embedding(x)       # (B, T,  C)\n",
    "        h = self.dp(self.lstm1(emb))  # (B, T,  H)\n",
    "        h = self.dp(self.lstm2(h))    # (B, T,  H)\n",
    "        h = self.dp(self.lstm3(h))    # (B, T,  H)\n",
    "        output = self.h2o(h)          # (B, T, vs)\n",
    "        return output\n",
    "\n",
    "model_norm = CharLSTMLayerNorm(len(tok.char2ind)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "khqYG4hYto8d",
    "outputId": "154fe844-b093-41b7-b31f-c26bf703a460"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch  0: train loss 1.1342, test loss 1.2981\n",
      "epoch  1: train loss 0.9863, test loss 1.1823\n",
      "epoch  2: train loss 0.9283, test loss 1.1456\n",
      "epoch  3: train loss 0.8908, test loss 1.1117\n",
      "epoch  4: train loss 0.8742, test loss 1.1144\n",
      "epoch  5: train loss 0.8481, test loss 1.0984\n",
      "epoch  6: train loss 0.8359, test loss 1.0962\n",
      "epoch  7: train loss 0.8202, test loss 1.0890\n",
      "epoch  8: train loss 0.8229, test loss 1.0829\n",
      "epoch  9: train loss 0.8128, test loss 1.0881\n"
     ]
    }
   ],
   "source": [
    "l_norm = train_lstm(model_norm, optim.Adam(model_norm.parameters(), lr=learning_rate),\n",
    "                    train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 448
    },
    "id": "GngdCJHkto8d",
    "outputId": "ab705301-e8fb-4ba4-a6f8-c9075dce6c5d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7b17184001f0>]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA7OklEQVR4nO3deXTU9b3/8ddMJjNZZ7IvhIRF9j2yBlvxKoqUKnTxWq4ttle9PyzeK9XThd5eW+uvN97j8f7q7YJS69KrlFZbsHUBEQUXFgGJEpB9SYAskGUm6ySZ+f7+SDIQICHb5Jswz8c5c9p8l8x7PqWZ1/lsX4thGIYAAABMYjW7AAAAENoIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmIowAgAATEUYAQAApiKMAAAAU9nMLqAz/H6/zpw5o9jYWFksFrPLAQAAnWAYhqqqqjRo0CBZre33fwyIMHLmzBllZmaaXQYAAOiGwsJCDR48uN3zAyKMxMbGSmr+ME6n0+RqAABAZ3g8HmVmZga+x9szIMJI69CM0+kkjAAAMMBcaYoFE1gBAICpCCMAAMBUhBEAAGCqHoWRxx9/XBaLRcuXL+/wuldeeUVjxoxRRESEJk6cqDfffLMnbwsAAK4i3Q4jO3fu1DPPPKNJkyZ1eN3WrVu1ePFi3XPPPdqzZ48WLVqkRYsWKT8/v7tvDQAAriLdCiPV1dW666679Lvf/U7x8fEdXvvUU0/p1ltv1fe//32NHTtWjz32mK699lr9+te/7lbBAADg6tKtMLJs2TItWLBAc+fOveK127Ztu+S6efPmadu2be3e4/V65fF42rwAAMDVqcv7jKxZs0affPKJdu7c2anri4uLlZqa2uZYamqqiouL270nNzdXjz76aFdLAwAAA1CXekYKCwv14IMP6uWXX1ZERESwatKKFSvkdrsDr8LCwqC9FwAAMFeXekZ2796t0tJSXXvttYFjPp9P77//vn7961/L6/UqLCyszT1paWkqKSlpc6ykpERpaWntvo/D4ZDD4ehKaQAAYIDqUs/ITTfdpL179yovLy/wmjZtmu666y7l5eVdEkQkKScnR5s2bWpzbOPGjcrJyelZ5QAA4KrQpZ6R2NhYTZgwoc2x6OhoJSYmBo4vWbJEGRkZys3NlSQ9+OCDmjNnjp588kktWLBAa9as0a5du7Rq1ape+ggAAGAg6/UdWAsKClRUVBT4efbs2Vq9erVWrVqlyZMn69VXX9W6desuCTVm+P2Hx/Wzv+3TweIqs0sBACBkWQzDMMwu4ko8Ho9cLpfcbnevPrX3K7/9SHsKKrXqW1N1y/j257AAAICu6+z3d0g/myas5ZHG/n4fxwAAuHqFdBixBsIIaQQAALOEdhhp+fQ+ukYAADBNSIeRMCs9IwAAmC2kwwjDNAAAmI8wIsnnN7kQAABCWIiHkeb/pGcEAADzhHQYCcwZYQIrAACmCekwYmWfEQAATEcYkeRjmAYAANOEdBhhmAYAAPOFdBixss8IAACmC+0w0rKahh1YAQAwT0iHkTA2PQMAwHQhHUbOD9OYXAgAACEstMMIwzQAAJgupMMIq2kAADBfSIcRC5ueAQBgupAOI2FsegYAgOlCO4y0DNMYhBEAAEwT0mHEwgRWAABMF9JhhGEaAADMF9phJDBMY3IhAACEsJAOI62raRimAQDAPCEdRsJaPj1hBAAA84R2GLGwmgYAALOFdBixMIEVAADThXQYaZ3A6vObXAgAACGMMCKGaQAAMFNIhxE2PQMAwHwhHUbY9AwAAPOFdBixWtj0DAAAs4V2GLGy6RkAAGYL6TAS1jJnxE/XCAAApgnpMNLaM0IYAQDAPKEdRng2DQAApgvpMBIW6BkxuRAAAEJYSIcRa+ucEdIIAACmCfEwwj4jAACYLaTDCMM0AACYL6TDSGvPCMM0AACYJ7TDCJueAQBgupAOI63PpmGfEQAAzBPSYcTKDqwA
AJgutMMIwzQAAJgutMOIhdU0AACYrUthZOXKlZo0aZKcTqecTqdycnL01ltvtXv9Cy+8IIvF0uYVERHR46J7S1jLp2eYBgAA89i6cvHgwYP1+OOPa+TIkTIMQy+++KIWLlyoPXv2aPz48Ze9x+l06uDBg4GfLS29Ef2BlQmsAACYrkth5Lbbbmvz8y9+8QutXLlS27dvbzeMWCwWpaWldb/CIDr/oDyTCwEAIIR1e86Iz+fTmjVrVFNTo5ycnHavq66u1pAhQ5SZmamFCxdq3759V/zdXq9XHo+nzSsYAjuwMmkEAADTdDmM7N27VzExMXI4HFq6dKnWrl2rcePGXfba0aNH67nnntNrr72ml156SX6/X7Nnz9apU6c6fI/c3Fy5XK7AKzMzs6tldgrDNAAAmM9iGF37Jm5oaFBBQYHcbrdeffVVPfvss9qyZUu7geRCjY2NGjt2rBYvXqzHHnus3eu8Xq+8Xm/gZ4/Ho8zMTLndbjmdzq6U26Edx8p056rtGp4crXcfvqHXfi8AAGj+/na5XFf8/u7SnBFJstvtGjFihCRp6tSp2rlzp5566ik988wzV7w3PDxc2dnZOnLkSIfXORwOORyOrpbWZQzTAABgvh7vM+L3+9v0YnTE5/Np7969Sk9P7+nb9gorT+0FAMB0XeoZWbFihebPn6+srCxVVVVp9erV2rx5szZs2CBJWrJkiTIyMpSbmytJ+vnPf65Zs2ZpxIgRqqys1BNPPKGTJ0/q3nvv7f1P0g3nV9OQRgAAMEuXwkhpaamWLFmioqIiuVwuTZo0SRs2bNDNN98sSSooKJDVer6zpaKiQvfdd5+Ki4sVHx+vqVOnauvWrZ2aX9IXeFAeAADm6/IEVjN0dgJMV+WfduvLv/pQqU6Hdvx4bq/9XgAA0Pnv75B+Nk2YlU3PAAAwG2FE0gDoHAIA4KoV0mGkJYvIRxgBAMA0IR5GWE0DAIDZQjqMnB+mMbkQAABCWEiHEXpGAAAwX2iHkdbVNHSNAABgmpAOI62bnrGaBgAA84R0GAmspmGYBgAA04R2GLngQXn0jgAAYI6QDiOtwzQSK2oAADBLSIcR6wVhhEmsAACYI7TDyAWfnnkjAACYI7TDCMM0AACYLqTDSOsOrBLDNAAAmCWkw4jtgjDS5PObWAkAAKErtMNImDWw10gDYQQAAFOEdBiRpPCw5iZoaCKMAABghpAPI3YbYQQAADMRRlp6Rhp9TGAFAMAMhBF6RgAAMFXIh5HAnBEmsAIAYIqQDyP0jAAAYK6QDyPhgTkjhBEAAMwQ8mGEnhEAAMxFGAlr3vWMnhEAAMxBGLExgRUAADMRRtiBFQAAU4V8GGFpLwAA5gr5MMIEVgAAzEUYYWkvAACmIozQMwIAgKlCPoycnzPCg/IAADBDyIcRekYAADBXyIcRtoMHAMBcIR9G6BkBAMBchBG2gwcAwFSEEXpGAAAwVciHEXZgBQDAXCEfRugZAQDAXIQRekYAADAVYcTG0l4AAMxEGAljmAYAADOFfBhhO3gAAMwV8mEk0h4mSapraDK5EgAAQlPIh5EYh02SVF1PGAEAwAxdCiMrV67UpEmT5HQ65XQ6lZOTo7feeqvDe1555RWNGTNGERERmjhxot58880eFdzbYiOaw0iVlzACAIAZuhRGBg8erMcff1y7d+/Wrl27dOONN2rhwoXat2/fZa/funWrFi9erHvuuUd79uzRokWLtGjRIuXn5/dK8b0hpiWMVHub5PczbwQAgL5mMQyjR9/ACQkJeuKJJ3TPPfdccu7OO+9UTU2NXn/99cCxWbNmacqUKXr66ac7/R4ej0cul0tut1tOp7Mn5V6ivtGnMf+xXpKU/+i8wLANAADomc5+f3d7zojP59OaNWtUU1OjnJycy16zbds2zZ07t82xefPmadu2bR3+bq/XK4/H0+YVLA6bVTZr88Pyquobg/Y+AADg8rocRvbu3auYmBg5HA4tXbpUa9eu1bhx
4y57bXFxsVJTU9scS01NVXFxcYfvkZubK5fLFXhlZmZ2tcxOs1gsgXkjTGIFAKDvdTmMjB49Wnl5edqxY4fuv/9+3X333dq/f3+vFrVixQq53e7Aq7CwsFd//8Va5414CCMAAPS5Lk+QsNvtGjFihCRp6tSp2rlzp5566ik988wzl1yblpamkpKSNsdKSkqUlpbW4Xs4HA45HI6ultZtsY5wSXWqZkUNAAB9rsf7jPj9fnm93suey8nJ0aZNm9oc27hxY7tzTMzS2jPCnBEAAPpel3pGVqxYofnz5ysrK0tVVVVavXq1Nm/erA0bNkiSlixZooyMDOXm5kqSHnzwQc2ZM0dPPvmkFixYoDVr1mjXrl1atWpV73+SHnAyZwQAANN0KYyUlpZqyZIlKioqksvl0qRJk7RhwwbdfPPNkqSCggJZrec7W2bPnq3Vq1frJz/5iX784x9r5MiRWrdunSZMmNC7n6KHWpfzVhFGAADoc10KI7///e87PL958+ZLjt1xxx264447ulRUX4uNCJfELqwAAJgh5J9NI0nOyOZM5q5tMLkSAABCD2FEUnyUXZJUXssEVgAA+hphRFJSTPMy4rLqy68KAgAAwUMYkZQQ3dIzUsMwDQAAfY0wovNhpIwwAgBAnyOMSEqMaQ4jFTUN6uFDjAEAQBcRRnS+Z6TJb8hTx/JeAAD6EmFEksMWFtj4rKyGSawAAPQlwkgLJrECAGAOwkgLJrECAGAOwkiLRHpGAAAwBWGkBcM0AACYgzDSIqFleW9ZNWEEAIC+RBhpkRiYM8JqGgAA+hJhpEVidPPzaRimAQCgbxFGWjBMAwCAOQgjLVhNAwCAOQgjLS5cTcPzaQAA6DuEkRatc0YafH5Ve3k+DQAAfYUw0iLSHqbI8DBJDNUAANCXCCMXYEt4AAD6HmHkAoktK2rKWVEDAECfIYxcgC3hAQDoe4SRCzBMAwBA3yOMXCCwJXw1W8IDANBXCCMXSGBLeAAA+hxh5AKtE1gZpgEAoO8QRi7AlvAAAPQ9wsgFWE0DAEDfI4xcoHVL+LIaJrACANBXCCMXSGiZM1Lf6FdtA8+nAQCgLxBGLhBtD5Pd1twkZezCCgBAnyCMXMBisTCJFQCAPkYYuQiTWAEA6FuEkYuwJTwAAH2LMHIRtoQHAKBvEUYuEtgSvpaeEQAA+gJh5CLxUeGSJHdto8mVAAAQGggjF4lrCSMV9IwAANAnCCMXiYtqnjNSSc8IAAB9gjBykdaeEXcdYQQAgL5AGLlIXGRzzwjDNAAA9A3CyEVae0YYpgEAoG8QRi7SGka8TX7VN/pMrgYAgKsfYeQiMQ6bbFaLJIZqAADoC4SRi1gsFoZqAADoQ10KI7m5uZo+fbpiY2OVkpKiRYsW6eDBgx3e88ILL8hisbR5RURE9KjoYHNFEkYAAOgrXQojW7Zs0bJly7R9+3Zt3LhRjY2NuuWWW1RTU9PhfU6nU0VFRYHXyZMne1R0sJ3fa4RhGgAAgs3WlYvXr1/f5ucXXnhBKSkp2r17t66//vp277NYLEpLS+tehSZo3RK+kr1GAAAIuh7NGXG73ZKkhISEDq+rrq7WkCFDlJmZqYULF2rfvn09edugc0WyCysAAH2l22HE7/dr+fLluu666zRhwoR2rxs9erSee+45vfbaa3rppZfk9/s1e/ZsnTp1qt17vF6vPB5Pm1dfCvSMMEwDAEDQdWmY5kLLli1Tfn6+Pvzwww6vy8nJUU5OTuDn2bNna+zYsXrmmWf02GOPXfae3NxcPfroo90trcdYTQMAQN/pVs/IAw88oNdff13vvfeeBg8e3KV7w8PDlZ2drSNHjrR7zYoVK+R2uwOvwsLC7pTZba7WCax19IwAABBsXeoZMQxD//qv/6q1a9dq8+bNGjZsWJff0Ofzae/evfrSl77U7jUOh0MOh6PLv7u3tA7TVNAzAgBA0HUpjCxbtkyrV6/Wa6+9ptjYWBUXF0uSXC6XIiMjJUlL
lixRRkaGcnNzJUk///nPNWvWLI0YMUKVlZV64okndPLkSd177729/FF6T+vD8tyEEQAAgq5LYWTlypWSpBtuuKHN8eeff17f/va3JUkFBQWyWs+P/lRUVOi+++5TcXGx4uPjNXXqVG3dulXjxo3rWeVBFJgzwjANAABBZzEMwzC7iCvxeDxyuVxyu91yOp1Bf79TFbX6wn+9J7vNqoOP3SqLxRL09wQA4GrT2e9vnk1zGa07sDY0+VXf6De5GgAArm6EkcuItocFntzLUA0AAMFFGLmM5if3NveOVNQwiRUAgGAijLSDSawAAPQNwkg74iKbwwjLewEACC7CSDsCwzSEEQAAgoow0g6GaQAA6BuEkXYwTAMAQN8gjLQjPrp1mIaeEQAAgokw0g5XS89IJT0jAAAEFWGkHfFR9IwAANAXCCPtSHU6JEklHq/JlQAAcHUjjLQjzRUhSSp212sAPEsQAIABizDSjpTYCFksUoPPr/IahmoAAAgWwkg77DarkmKah2qK3PUmVwMAwNWLMNKB9AuGagAAQHAQRjqQ5mwOI0UewggAAMFCGOnA+Z6ROpMrAQDg6kUY6UCaK1KSVFRJzwgAAMFCGOlAa88IE1gBAAgewkgHAnuNMGcEAICgIYx0YFDrMI27jo3PAAAIEsJIB1JatoSvb/TLXccD8wAACAbCSAciwsOUGN38wDzmjQAAEByEkStIY+MzAACCijByBayoAQAguAgjV5DGxmcAAAQVYeQK0gMraugZAQAgGAgjV9D6fBr2GgEAIDgII1fAnBEAAIKLMHIFrXNGiirZ+AwAgGAgjFxBaxipafCpyttkcjUAAFx9CCNXEGW3yRUZLom9RgAACAbCSCcwbwQAgOAhjHRCOnuNAAAQNISRTkhjrxEAAIKGMNIJ6TyfBgCAoCGMdEIac0YAAAgawkgn0DMCAEDwEEY64fxqGiawAgDQ2wgjndA6gdVT36QaNj4DAKBXEUY6IcZhU6zDJokH5gEA0NsII52UxrwRAACCgjDSSa1h5Ewl80YAAOhNhJFOGsTGZwAABAVhpJMGxzeHkcLyWpMrAQDg6kIY6aSsxChJ0knCCAAAvapLYSQ3N1fTp09XbGysUlJStGjRIh08ePCK973yyisaM2aMIiIiNHHiRL355pvdLtgsWQnNYYSeEQAAeleXwsiWLVu0bNkybd++XRs3blRjY6NuueUW1dTUtHvP1q1btXjxYt1zzz3as2ePFi1apEWLFik/P7/Hxfel1jBS7KlXfaPP5GoAALh6WAzDMLp789mzZ5WSkqItW7bo+uuvv+w1d955p2pqavT6668Hjs2aNUtTpkzR008/3an38Xg8crlccrvdcjqd3S23RwzD0MSfva1qb5Peeeh6jUiJNaUOAAAGis5+f/dozojb7ZYkJSQktHvNtm3bNHfu3DbH5s2bp23btrV7j9frlcfjafMym8ViCfSOnDjHUA0AAL2l22HE7/dr+fLluu666zRhwoR2rysuLlZqamqbY6mpqSouLm73ntzcXLlcrsArMzOzu2X2qmtSYiRJR89Wm1wJAABXj26HkWXLlik/P19r1qzpzXokSStWrJDb7Q68CgsLe/09uuOa5GhJhBEAAHqTrTs3PfDAA3r99df1/vvva/DgwR1em5aWppKSkjbHSkpKlJaW1u49DodDDoejO6UF1fDk5p6RY2fbn7ALAAC6pks9I4Zh6IEHHtDatWv17rvvatiwYVe8JycnR5s2bWpzbOPGjcrJyelapf0APSMAAPS+LvWMLFu2TKtXr9Zrr72m2NjYwLwPl8ulyMjmHUqXLFmijIwM5ebmSpIefPBBzZkzR08++aQWLFigNWvWaNeuXVq1alUvf5TgG57U3DNSUduo8poGJUTbTa4IAICBr0s9IytXrpTb7dYNN9yg9PT0wOtPf/pT4JqCggIVFRUFfp49e7ZWr16tVatWafLkyXr11Ve1bt26Die99leR9jBlxDWH
LnpHAADoHV3qGenMliSbN2++5Ngdd9yhO+64oytv1W8NT47W6co6HTtbrelD21/SDAAAOodn03TRNcmty3uZxAoAQG8gjHRRYK+RUoZpAADoDYSRLhqW2Lyi5kQZPSMAAPQGwkgXDUlseXpvRZ38/m4/1gcAALQgjHRRuitCNqtFDU1+FXvqzS4HAIABjzDSRbYwqzJbH5jHUA0AAD1GGOmG1qf3FpTx9F4AAHqKMNINw5LYFh4AgN5CGOmGUamxkqRDJYQRAAB6ijDSDaPTmvcaOVRSZXIlAAAMfISRbhiR0twzUuSul6e+0eRqAAAY2Agj3eCKDFe6K0KS9PkZj8nVAAAwsBFGumlKZpwk6ZOCSlPrAABgoCOMdNPUIfGSpN0nK0yuBACAgY0w0k3XtoSRTwoqZBhsCw8AQHcRRrpp/CCn7DarymsadILNzwAA6DbCSDc5bGGalOGSxFANAAA9QRjpAeaNAADQc4SRHgjMGyGMAADQbYSRHrg2qzmMHCqtYvMzAAC6iTDSA8mxDmUlRMkwpDz2GwEAoFsIIz00rWWo5qOj50yuBACAgYkw0kP/MCZFkrRxf4nJlQAAMDARRnpozuhkhYdZdOxsjY6erTa7HAAABhzCSA85I8I1a3iiJOkdekcAAOgywkgvuHlcqiSGagAA6A7CSC+4aWxzGPmkoELuWpb4AgDQFYSRXpARF6lrkqPlN6Rtx1hVAwBAVxBGeskXRiRJkj44TBgBAKArCCO95AsjkyVJHx4hjAAA0BWEkV4ya3iCwqwWnSyrVWF5rdnlAAAwYBBGeklsRLiyM+MkSZsPnTW3GAAABhDCSC+a27LE9++fnjG5EgAABg7CSC+6ffIgWSzSx8fLVeSuM7scAAAGBMJILxoUF6lJGS5J0o5j5SZXAwDAwEAY6WXThyZIknaeIIwAANAZhJFeNq0ljHx45JwamvwmVwMAQP9HGOllOcMTFRth08myWj216ZDZ5QAA0O8RRnqZKypcuV+dKEl6eUeBvE0+kysCAKB/I4wEwfwJ6Up3RaiytpEn+QIAcAWEkSAIs1r09amDJUl/2llocjUAAPRvhJEg+cdpmZKaJ7KeLKsxuRoAAPovwkiQZCZE6YbRyTIM6XcfHDO7HAAA+i3CSBD9n+uvkST9ZfdpVXubTK4GAID+iTASRLOGJ2hYUrTqGn16a2+R2eUAANAvEUaCyGKx6GvXZkiSfrv5qOobWeYLAMDFuhxG3n//fd12220aNGiQLBaL1q1b1+H1mzdvlsViueRVXFzc3ZoHlCWzhyrV6dDxczX648cFZpcDAEC/0+UwUlNTo8mTJ+s3v/lNl+47ePCgioqKAq+UlJSuvvWA5IwI13dvGCFJevTv+7X7ZIXJFQEA0L/YunrD/PnzNX/+/C6/UUpKiuLi4rp839Vg4ZRBeuz1/WryG7rvD7u048c3KTyMETIAAKQ+nDMyZcoUpaen6+abb9ZHH33U4bVer1cej6fNayCLi7LrN3ddK0kqr2nQ1qNlJlcEAED/EfQwkp6erqefflp/+ctf9Je//EWZmZm64YYb9Mknn7R7T25urlwuV+CVmZkZ7DKDbt74NH1r1hBJ0t8/PWNyNQAA9B8WwzCMbt9ssWjt2rVatGhRl+6bM2eOsrKy9L//+7+XPe/1euX1egM/ezweZWZmyu12y+l0drdc0+04VqY7V22XJM2+JlEvfGeG7DaGawAAVyePxyOXy3XF729TvglnzJihI0eOtHve4XDI6XS2eV0Npg9NUER4c5NvPVqmDw6fNbkiAADMZ0oYycvLU3p6uhlvbSqr1aKf3jY+8PNf95w2sRoAAPqHLq+mqa6ubtOrcfz4ceXl5SkhIUFZWVlasWKFTp8+rT/84Q+SpF/+8pcaNmyYxo8fr/r6ej377LN699139fbbb/fepxhAFs/I0siUGH396W1647MizRp2Qt/KGWp2WQAAmKbLPSO7du1Sdna2srOzJUkPPfSQ
srOz9cgjj0iSioqKVFBwfnOvhoYGPfzww5o4caLmzJmjTz/9VO+8845uuummXvoIA8+0oQl64B+a9x557PXPdaB4YK8WAgCgJ3o0gbWvdHYCzEBiGIbufXGXNh0o1ejUWK1dNltR9i53VAEA0G/16wmsaF6J9F9fn6SkGIcOllRp3CMb9OruU2aXBQBAnyOMmCgpxqGV37w28PNjr+9XRU2DiRUBAND3CCMmmz40QfmPzlNGXKTcdY164u2DZpcEAECfIoz0AzEOm574+iRJ0uodBfr53/fLXddoclUAAPQNwkg/MXtEkh68aaQk6bmPjmvyo2/rL8whAQCEAMJIP/K9m0fpG9PPP4fnB3/5TJ+dqjSvIAAA+gBhpJ95dOF4/fqfsjVzWIJ8fkMP/flTna6s07lq75VvBgBgAGKfkX6qoqZBt/zyfZ2tag4hidF2vff9G+SMCDe5MgAAOod9Rga4+Gi7/utrEwM/l9U06H/eOawBkB0BAOgSwkg/duOYVN1/wzWBn5/98Lh+8Opn8vkJJACAqwdhpJ/74a1jdPD/3qpvzsqS1SK9svuUbvl/W7Tp8xKzSwMAoFcwZ2QAWZ9fpH/94x41+gxZLdKd07P0pYlp+uLIZLNLAwDgEswZuQrdOiFdb39vjmYOS5DfkP74cYG+8/xOnThXY3ZpAAB0G2FkgBmWFK3V983SyruuVVKMQ01+Qz949TN9XuRhcisAYEAijAxAYVaL5k9M10v3zpDDZtXHJ8o1/6kPtOi3W3W2yitvk8/sEgEA6DTmjAxw246W6b83HtSnhW41+PySJLvNqrljU3TjmFR9JTtDYVaLyVUCAEJRZ7+/CSNXiWNnq/XNZ3fojLu+zfG7c4bo0YUTTKoKABDKOvv9bevDmhBEw5NjtG7ZdXorv1gVtQ06ca5G6/LO6MVtJzVtaILKaxp0prJOS+dco7iocFks9JYAAPoHekauYv+1/oBWbj56yfEJGU79x4Jxmjk80YSqAAChgqW90MM3j9Kt49MuOZ5/2qM7V23X+vwi/f7D4/rmszv0v9tPmlAhAAD0jFz1/H5D24+VaXRarGq8Pr247YR+/+FxSc2rci7cWv6Jr0/SV68dzIRXAECvYAIr2tXQ5NfXn96qz065ZbVIFz7qJsZh010zs/TAjSP0p52FmpjhYjgHANAthBF06FRFrX75zmEtmJQue5hVdz2747LX2cOseuobU5RzTaLioux9XCUAYCAjjKDLTlXU6k87C/Wrd49c9vyEDKcmZsRpREqM/vm6oazIAQB0iDCCblu5+ajW7CzQz24fr80HSvXG3iKdq2645LpHvjxO6/cVyxUZru/NHaVxg/jfBgBwHmEEvepslVf/8r+7tKeg8rLn7Tar/mlGluZPSNP0oQmyMgkWAEIeYQS9zuc3VOKp17sHSrVy81GdrqyTJA1Pitaxi54cPDwpWj/58lidLKvVi1tPaP7EdP1g3miGdgAghBBGEHQFZbUKt1mU5ozQh0fO6XcfHNf7h862e/0t41J1y/g0naqo1bzxaQqzWjQyJYaAAgBXKcIITJF/2q2tR8/pQFGV/rrntCLCrRqREqN9Zzy63L+0CRlO3TgmVXNGJWlMmlPRDp5QAABXC8IITFdR0yBHuFVRdpv2n/Ho1+8d1pt7izu8Z2RKjK5JjtGtE9I0d1yqmnx+fXy8XHPHpgbmoRS56xQbEa4YggsA9GuEEfRbf9h2QnmFlfrG9Cy9/tkZbT1apiOl1ZdclxhtV1lN8yqecelOjUiJ0eaDpfLUNykuKlzfveEaTR4cx6ZsANBPEUYwYDT5/Hr3QKniouzKK6zQn3YW6kRZbZut6jsyyBWhBp9fQxOjdeuENPkNQyNSYvTFkckKD7v08UvvHSjVxyfK9dDNoy57HgDQOwgjGNDqG316aftJfXD4nLYcOiuLRfrWrCEqq2nQG58Vdep3xEeFKybCprLqBk3IcCnMYlGT36+dJyokNe+Tkhhj
156CSmVnxSk7M15Hz1br+lHJPJ8HAHoBYQRXjcLyWrnrGjUhwyVJctc26tkPj+lr1w7WurzT+s17RzRnVLIkqcbrU3KsQ1uPlulctbdb7zc23an/c/1wXTciSbERNkWEh8kwDB0urdawpGh6UwCgkwgjCBn1jT5FhIe1Odbk8+vTU5V6e1+Jzrjr1eTz66385smzQxKjdLKsVpIUG2HTzGGJeufzknZ/f2R4mGIjbCqt8mrGsAT97lvTtK/IrYkZLlksFp2prFNmfJQiwq0sUwaACxBGgMtoDS7r84skWTRvfKosFovW5xfr5R0ndfvkQXpjb5E+O+VWec2lW+BLUniYRY2+S/9vE+OwKSHaroLyWo1Ji9XYdKemD02QxSLdPC5VPr+hVGfEZX9neU2DDMNQYoyjNz8uAJiKMAL0gGEYqmv0ydvo18nyWlXUNMhike5/6RPVNfq6/XtTnQ7dOCZFTT5DO46Xq9HnV0ZcpHadrJDVIt00NlW3TR6kGm+T5oxK1pt7izR9aIImZ8bpZFmNBsVFthkmavT5ZbVYmOMCoF8ijABBcLikSoUVtcqMj9KGfcU6VVGn7Kw45QxP0rFz1frxX/fqjLteSTGObs9ZuZjVIl2THKPDLcufJ2fGKdxqkb9lHktyjEPfmJGpiRlxGpseq4PFVZo2NEEHij0amRIru405LgDMQRgBTNDk86u0yqtBcZHafbJcb+8v0c1jU/XC1hOaPjRBnxRUaMO+YmUlROlfbxypukafVvx1r8KsFn01O0N7CipVWlWvitrGXqknIy5S8yekyZC060S5Up0RGpUaq3nj09Tkb14O7alv1LGzNZoxLEFR9jBt2Feisemxyj/tkd8wdNvkQb1SC4DQQxgBBoi9p9yyWBRYLSRJ+8649dkpt76SnaGjZ6u1v2U7/df3FulAkUelVV5NzHDptsnpKvV4lX/Gre3HyoNW44JJ6TpQ5FFGfJTGpTtV39i8aqn1z0dsRLimD03Q3z49o3njU5WdFX/Z39Pk88sWZpXPb+hQSZUy4iPljAgPWt0AzEUYAULMR0fOySLpUEmVjp+r0bdyhmjniQpt3F+ik2U1yrkmUSOSY7R+X7F2HC+Xw2ZVfaM/KLXkDE/U4dIqnatungQcHxUub5NfDU1+jR/kVFV9k46dq5ErMlzPfGuqou02HSqpUmKMXRHhYUqMtmt/kUdTMuM0JDFahmHIMBR4JMCF3HWNinXYLnsOgLkIIwDa1ejzy2a1KP+0Rw0+v65Jjtabe4vV0OTTibJabdxfosUzMjVpcJy2HStTYXmtGn1+OWxh2n2yQvWNvsBW/cGWHOuQ1SKVVnmVFOPQNcnR+uLIZEWGh8lvGPqv9Qd0TXKM/vkLw1TjbZIrMlxl1Q1y1zXK2+RTVmK0bFaLvjwpXdF2m46crdaZyjptPVomb6NP/zRziEanxQbez+83tO1YmcakxbK6CeghwgiAoGj9k/HRkTJ9eqpSo1JjlZ0VpxJPvfIKK/XS9gKVeupVVtOgUakxOlRSrWuSo3XvF4fLU9eoJzYc1JxRycr92kQte/kT7TxRIYtFskhqfQKA1SKNSXPqYElVpx8L0BmxDpuqvE2XHLdZLRqb7tTZKq+KPfWSpCh7mP5hTIpOnKtRVX2TZgxL0IyhCfI2+XTHtEwdPVsth82qGEe4XJHhMmToQHGVMuOjtPlgqY6UVqumoUkLJg7SoZIqXTciSSNSYi5pyypvk2IdNr3zeamys+KU1MUA5Pcb9Aqh3yKMAOgXDMNosxlctbcp8MTlJp9fH58oV7orUhlxkTpdWaeGJr9ckeFKc0XoxLka/fHjAmVnxWtKZpz++HGB3jtYKqvFos9OVcqQNG9cmtLjIpRXWNm8zNli0YmyGpVWtb+ayW6z6sbRKVq/r+OnSPe2xTOy9MHhs/L5DS2ckqGPjpzT3tPuwPnJg1367j+M0B+2ndDCyRk6XVknSfrypHT9/sPjeufzUk0e7NKC
SelaOCVDnxRUaPmaPCXFOvT1qYM1bUh8IFQ9seGAZg1P1FevHXzFugzD0Gen3BqaGC1XVMdzeNx1jaqsbdCQxOg295+t8iqlnX10ELoIIwCuaj6/ocrahnaHUorcdfrkZKU+O1Wp26cM0t/yzmjOqGRNyYpTZHiYLBaLth0t0wtbj+v2yRk6fq5aO09UKDbCpvkT0hVpt2r7sXKt23Na5TUNckaGKzI8LBAQOhIbYVNSjEPHz9X09scOiAi/dM6PxSJ9edIg7TpRriJ3cw/PqNQY1TX6NH1ogm6fPEgvbS9QZW2DYiNsiokIlz3Mqt0ny3WirFaxDptuHpeqf5yeqSOl1Sosr9U9XximI6XVeu9g88MsV+8o0OnKOv3jtMEalhQju82qY2er9fKOAi2ekalfLJqoBp9f3pZQ2arUU6+dJyp0y/hUVdU3Ka+wQteNSJI97PI7FxuGofpGv8LDmsPlNckx7HA8ABFGAKCXXNi7sz6/SKcq6vT1qYOVf9qj8toGnThXo72n3frZ7eMVbrUoMcahMKtFz35wTG/vL9GMoQn67eYjumVcmhZOGaTdJyv06ienlBkfpZ8sGKsVf92rY+dqZLFIyTGONr06zgibPPXNQ0uR4WFtNt0Ls1qUEG3X6NRY1TY06ZOCyj5tl460fpYGn183j03VO5+XqKK2UYNcETrTEpQkaWKGS7OGJ6i8plETM5w6466Xz29o98kKfXqqUtF2m6q9TUqOdWjmsATFR9lVXtugoYlROlFWq4KyWv33P05WZV2jrBaL0l0RavT5leaKkN/f3AtW4qnX7z88rgkZTr2yq7ndZ49I1A2jUhRht8phC1NFTYPePVCqhGi75oxKVlV9kxp8flksuuzQWW1DUyDUXqh1snVtoy/QAxjKghZG3n//fT3xxBPavXu3ioqKtHbtWi1atKjDezZv3qyHHnpI+/btU2Zmpn7yk5/o29/+dqffkzACYKC73DOUWrWGnQtDz9Gz1fL7DY1MjdWhkipV1Tdp6pB41TY0qdFn6PMij8amOdsMq2w+WKo39xZpeHKM5oxK1qr3j2lMWqwq6xr1/qGzKq3yKmd4om4ck6JtR8vU4POr2F2vw6XV+uWdU1RR26Dff3hcx85WKy7KroqaBlW1DKsNTYpS/mlPm7qHJ0XrWBB7f3qDxSJd6VsuzRkhd11jIOhZLefnL0nS3LEpWj53lLxNfh0uqdLnRR6t2VmoselOfSU7Q+8dLFVBWa3OVXsVZrXIFRmu0iqv/u+iCcpKiNKB4irNGZWsx986oE9PVWr2NYn6znXDFOOw6Sfr8tXQ5Fe0w6azVfWqrGuUKzJcEzNcWj53lKyW5tAZF2XX3lNupTodgeGwRp9f5TUNslktOlBcJVdkuMYPav6ObGiZcN5ZHf377ImghZG33npLH330kaZOnaqvfvWrVwwjx48f14QJE7R06VLde++92rRpk5YvX6433nhD8+bN69UPAwDouovn9bRq8vn1weFzGj/IKWdkuF7ZfUqzr0lUYrRdkfYwOWxhctc2aueJck0bGq8/7yrUybJaJcY4NDHDpWh7mNx1jdp9skJR9jDNGJaoA8WewETfcJtFb+8r0d7TbsVH2XWu2qsJGU4lxTiUGO3QDaOTFR9lV1KsXTuOletAcZViI2w6UFyl9w+d7fV2SIi2t/tMKjNZLFKE7XyvWFxUuOIiw3WuukHVF03IHpYULW+jT2fc9RqWFC1nZLjCLFJMRLjio8J13xeH63RlnbYdLdOIlBjVeJv08fFybTtWpq0/ulFxUfZerb1PhmksFssVw8gPf/hDvfHGG8rPzw8c+8Y3vqHKykqtX7++U+9DGAGAq5Pfb8hT39wbcLKsVlkJUVdcHeSubVTuW59r3oQ0TR0Sr8YmvxKi7frbp2cUH2XXF0cmaduxMsU6wrXt2DmdqazXPV8YpoPFVUp1Ruj7r36q+744XPVNPlXWNqq2oUlfGJGsWcMT9FreGR0o
rtKXJqZpXLpTn56q1Hdf/kQlnuahs1iHTTnXJMoVGa4mv6Hj52o0PClaCyal6638Yr26+5QGuSKUlRilvMLKNvN6ou1hqmm4/LOtWleejUmL1bdnD9Xqjwv02Sn3Za+9nEGuCBV76tWTxWe/Wpzd6zsu95swcv311+vaa6/VL3/5y8Cx559/XsuXL5fbffmG9nq98nrPj5l6PB5lZmYSRgAAfa7G26RTFXXyG8YVdw0+VFKlrIQoRYSHqbK2QXsKKjV1aLxOV9RpREqMbFaLKmsb9ffPzujY2RpNyYzTlyelyxZmVVm1VwnRdlksFvn9hvYXeRQRbpUr0q7tx8qaNwSMsSvCFqaSqnolRTs0Nj1WPsOQwxamwvJaHS6tUpjVquFJ0Xpx6wk999HxSwJKRLhVkzLitL/Ioyh7mO6ePVQ3j0vVqNTYy3+oHuhsGAn67Jri4mKlpqa2OZaamiqPx6O6ujpFRkZeck9ubq4effTRYJcGAMAVRTtsbTbG68iFX+hxUXb9w5gUSZIz/XyAiY+2a0nO0EvuvXBlmNVqafOIiIt7LMbp/Bd76xd5ZkKUMhOiAsd/8uVxWn7zKEXYrFqzs1AJ0XbVNfh0+5RBCg+zqrahSWFWS5fmlgRLv5zqu2LFCj300EOBn1t7RgAAQOe1ruj55qwhl5yLsvefCBD0StLS0lRSUtLmWElJiZxO52V7RSTJ4XDI4WAbZgAAQoE12G+Qk5OjTZs2tTm2ceNG5eTkBPutAQDAANDlMFJdXa28vDzl5eVJal66m5eXp4KCAknNQyxLliwJXL906VIdO3ZMP/jBD3TgwAH99re/1Z///Gd973vf651PAAAABrQuh5Fdu3YpOztb2dnZkqSHHnpI2dnZeuSRRyRJRUVFgWAiScOGDdMbb7yhjRs3avLkyXryySf17LPPdnqPEQAAcHVjO3gAABAUnf3+DvqcEQAAgI4QRgAAgKkIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmIowAgAATEUYAQAApuo/j+zrQOu+bB6Px+RKAABAZ7V+b19pf9UBEUaqqqokSZmZmSZXAgAAuqqqqkoul6vd8wNiO3i/368zZ84oNjZWFoul136vx+NRZmamCgsL2Wb+CmirzqGdOo+26hzaqfNoq87py3YyDENVVVUaNGiQrNb2Z4YMiJ4Rq9WqwYMHB+33O51O/uF2Em3VObRT59FWnUM7dR5t1Tl91U4d9Yi0YgIrAAAwFWEEAACYKqTDiMPh0E9/+lM5HA6zS+n3aKvOoZ06j7bqHNqp82irzumP7TQgJrACAICrV0j3jAAAAPMRRgAAgKkIIwAAwFSEEQAAYKqQDiO/+c1vNHToUEVERGjmzJn6+OOPzS6pT73//vu67bbbNGjQIFksFq1bt67NecMw9Mgjjyg9PV2RkZGaO3euDh8+3Oaa8vJy3XXXXXI6nYqLi9M999yj6urqPvwUwZebm6vp06crNjZWKSkpWrRokQ4ePNjmmvr6ei1btkyJiYmKiYnR1772NZWUlLS5pqCgQAsWLFBUVJRSUlL0/e9/X01NTX35UYJu5cqVmjRpUmAzpZycHL311luB87TT5T3++OOyWCxavnx54Bht1exnP/uZLBZLm9eYMWMC52mn806fPq1vfvObSkxMVGRkpCZOnKhdu3YFzvfrv+lGiFqzZo1ht9uN5557zti3b59x3333GXFxcUZJSYnZpfWZN9980/j3f/93469//ashyVi7dm2b848//rjhcrmMdevWGZ9++qlx++23G8OGDTPq6uoC19x6663G5MmTje3btxsffPCBMWLECGPx4sV9/EmCa968ecbzzz9v5OfnG3l5ecaXvvQlIysry6iurg5cs3TpUiMzM9PYtGmTsWvXLmPWrFnG7NmzA+ebmpqMCRMmGHPnzjX27NljvPnmm0ZSUpKxYsUKMz5S0Pztb38z3njjDePQoUPGwYMHjR//+MdGeHi4kZ+fbxgG7XQ5
H3/8sTF06FBj0qRJxoMPPhg4Tls1++lPf2qMHz/eKCoqCrzOnj0bOE87NSsvLzeGDBlifPvb3zZ27NhhHDt2zNiwYYNx5MiRwDX9+W96yIaRGTNmGMuWLQv87PP5jEGDBhm5ubkmVmWei8OI3+830tLSjCeeeCJwrLKy0nA4HMYf//hHwzAMY//+/YYkY+fOnYFr3nrrLcNisRinT5/us9r7WmlpqSHJ2LJli2EYze0SHh5uvPLKK4FrPv/8c0OSsW3bNsMwmoOf1Wo1iouLA9esXLnScDqdhtfr7dsP0Mfi4+ONZ599lna6jKqqKmPkyJHGxo0bjTlz5gTCCG113k9/+lNj8uTJlz1HO533wx/+0PjCF77Q7vn+/jc9JIdpGhoatHv3bs2dOzdwzGq1au7cudq2bZuJlfUfx48fV3FxcZs2crlcmjlzZqCNtm3bpri4OE2bNi1wzdy5c2W1WrVjx44+r7mvuN1uSVJCQoIkaffu3WpsbGzTVmPGjFFWVlabtpo4caJSU1MD18ybN08ej0f79u3rw+r7js/n05o1a1RTU6OcnBza6TKWLVumBQsWtGkTiX9TFzt8+LAGDRqk4cOH66677lJBQYEk2ulCf/vb3zRt2jTdcccdSklJUXZ2tn73u98Fzvf3v+khGUbOnTsnn8/X5h+nJKWmpqq4uNikqvqX1nboqI2Ki4uVkpLS5rzNZlNCQsJV245+v1/Lly/XddddpwkTJkhqbge73a64uLg2117cVpdry9ZzV5O9e/cqJiZGDodDS5cu1dq1azVu3Dja6SJr1qzRJ598otzc3EvO0VbnzZw5Uy+88ILWr1+vlStX6vjx4/riF7+oqqoq2ukCx44d08qVKzVy5Eht2LBB999/v/7t3/5NL774oqT+/zd9QDy1F+gvli1bpvz8fH344Ydml9JvjR49Wnl5eXK73Xr11Vd19913a8uWLWaX1a8UFhbqwQcf1MaNGxUREWF2Of3a/PnzA/990qRJmjlzpoYMGaI///nPioyMNLGy/sXv92vatGn6z//8T0lSdna28vPz9fTTT+vuu+82uborC8mekaSkJIWFhV0y47qkpERpaWkmVdW/tLZDR22Ulpam0tLSNuebmppUXl5+VbbjAw88oNdff13vvfeeBg8eHDielpamhoYGVVZWtrn+4ra6XFu2nrua2O12jRgxQlOnTlVubq4mT56sp556ina6wO7du1VaWqprr71WNptNNptNW7Zs0f/8z//IZrMpNTWVtmpHXFycRo0apSNHjvBv6gLp6ekaN25cm2Njx44NDGn197/pIRlG7Ha7pk6dqk2bNgWO+f1+bdq0STk5OSZW1n8MGzZMaWlpbdrI4/Fox44dgTbKyclRZWWldu/eHbjm3Xffld/v18yZM/u85mAxDEMPPPCA1q5dq3fffVfDhg1rc37q1KkKDw9v01YHDx5UQUFBm7bau3dvm/+jb9y4UU6n85I/IFcbv98vr9dLO13gpptu0t69e5WXlxd4TZs2TXfddVfgv9NWl1ddXa2jR48qPT2df1MXuO666y7ZcuDQoUMaMmSIpAHwNz2o02P7sTVr1hgOh8N44YUXjP379xv/8i//YsTFxbWZcX21q6qqMvbs2WPs2bPHkGT893//t7Fnzx7j5MmThmE0LwOLi4szXnvtNeOzzz4zFi5ceNllYNnZ2caOHTuMDz/80Bg5cuRVt7T3/vvvN1wul7F58+Y2ywtra2sD1yxdutTIysoy3n33XWPXrl1GTk6OkZOTEzjfurzwlltuMfLy8oz169cbycnJV93ywh/96EfGli1bjOPHjxufffaZ8aMf/ciwWCzG22+/bRgG7dSRC1fTGAZt1erhhx82Nm/ebBw/ftz46KOPjLlz5xpJSUlGaWmpYRi0U6uPP/7YsNlsxi9+8Qvj8OHDxssvv2xERUUZL730UuCa/vw3PWTDiGEYxq9+9SsjKyvLsNvtxowZM4zt27ebXVKf
eu+99wxJl7zuvvtuwzCal4L9x3/8h5Gammo4HA7jpptuMg4ePNjmd5SVlRmLFy82YmJiDKfTaXznO98xqqqqTPg0wXO5NpJkPP/884Fr6urqjO9+97tGfHy8ERUVZXzlK18xioqK2vyeEydOGPPnzzciIyONpKQk4+GHHzYaGxv7+NME1z//8z8bQ4YMMex2u5GcnGzcdNNNgSBiGLRTRy4OI7RVszvvvNNIT0837Ha7kZGRYdx5551t9s6gnc77+9//bkyYMMFwOBzGmDFjjFWrVrU535//plsMwzCC2/cCAADQvpCcMwIAAPoPwggAADAVYQQAAJiKMAIAAExFGAEAAKYijAAAAFMRRgAAgKkIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmIowAgAATPX/AaAn656peCMbAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot l_norm averaged over consecutive, non-overlapping windows of entries.\n",
    "# NOTE(review): l_norm is presumably the per-step loss history recorded while training model_norm; confirm against the training cell.\n",
    "smooth_window = 10\n",
    "loss_history = torch.tensor(l_norm).reshape(-1)\n",
    "# Keep only complete windows: a trailing partial window (or an empty history) would make the 2-D view below raise.\n",
    "n_full_windows = loss_history.numel() // smooth_window\n",
    "windowed = loss_history[: n_full_windows * smooth_window].view(n_full_windows, smooth_window)\n",
    "# Keep plt.plot(...) as the last expression so the cell output is the same as before.\n",
    "plt.plot(windowed.mean(1).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "w87iEsiato8d",
    "outputId": "49c389d9-c457-4d3e-fb17-d8e67341b6ab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def numValues(tempfile.mkdtemp(), df.j1).collect()\n",
      "        [Row(age2=5, name=u'Alice')]\n",
      "        \"\"\"\n",
      "        if len(other) >= 3:\n",
      "            raise ValueError(\"Correlation in a sustance with the bases to and batching with Value thes no data,\n",
      "        while heap is expected.\n",
      "        \"\"\"\n",
      "        def returnTy\n"
     ]
    }
   ],
   "source": [
    "# Generate text with the model, starting from the prompt 'def '\n",
    "prompt_ids = tok.encode('def ')\n",
    "# unsqueeze(0) adds a leading dimension of size 1 (presumably the batch axis expected by generate_batch; confirm against its definition)\n",
    "begin_text = torch.tensor(prompt_ids, device=device).unsqueeze(0)\n",
    "generated = generate_batch(model_norm, begin_text)\n",
    "# tok.decode yields pieces of text that are concatenated into a single string before printing\n",
    "print(''.join(tok.decode(generated)))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
